diff --git a/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb b/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb
index eb4cfdbc..fd8fc1d0 100644
--- a/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb
+++ b/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb
@@ -41,7 +41,9 @@
"id": "e13eb331",
"metadata": {},
"outputs": [],
- "source": ["%pip install --upgrade --quiet langchain-nvidia-ai-endpoints"]
+ "source": [
+ "%pip install --upgrade --quiet langchain-nvidia-ai-endpoints"
+ ]
},
{
"cell_type": "markdown",
@@ -65,11 +67,22 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"id": "686c4d2f",
"metadata": {},
"outputs": [],
- "source": ["import getpass\nimport os\n\n# del os.environ['NVIDIA_API_KEY'] ## delete key and reset\nif os.environ.get(\"NVIDIA_API_KEY\", \"\").startswith(\"nvapi-\"):\n print(\"Valid NVIDIA_API_KEY already in environment. Delete to reset\")\nelse:\n nvapi_key = getpass.getpass(\"NVAPI Key (starts with nvapi-): \")\n assert nvapi_key.startswith(\"nvapi-\"), f\"{nvapi_key[:5]}... is not a valid key\"\n os.environ[\"NVIDIA_API_KEY\"] = nvapi_key"]
+ "source": [
+ "import getpass\n",
+ "import os\n",
+ "\n",
+ "# del os.environ['NVIDIA_API_KEY'] ## delete key and reset\n",
+ "if os.environ.get(\"NVIDIA_API_KEY\", \"\").startswith(\"nvapi-\"):\n",
+ " print(\"Valid NVIDIA_API_KEY already in environment. Delete to reset\")\n",
+ "else:\n",
+ " nvapi_key = getpass.getpass(\"NVAPI Key (starts with nvapi-): \")\n",
+ " assert nvapi_key.startswith(\"nvapi-\"), f\"{nvapi_key[:5]}... is not a valid key\"\n",
+ " os.environ[\"NVIDIA_API_KEY\"] = nvapi_key"
+ ]
},
{
"cell_type": "markdown",
@@ -91,7 +104,14 @@
"outputId": "e9c4cc72-8db6-414b-d8e9-95de93fc5db4"
},
"outputs": [],
- "source": ["## Core LC Chat Interface\nfrom langchain_nvidia_ai_endpoints import ChatNVIDIA\n\nllm = ChatNVIDIA(model=\"mistralai/mixtral-8x7b-instruct-v0.1\")\nresult = llm.invoke(\"Write a ballad about LangChain.\")\nprint(result.content)"]
+ "source": [
+ "## Core LC Chat Interface\n",
+ "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n",
+ "\n",
+ "llm = ChatNVIDIA(model=\"mistralai/mixtral-8x7b-instruct-v0.1\")\n",
+ "result = llm.invoke(\"Write a ballad about LangChain.\")\n",
+ "print(result.content)"
+ ]
},
{
"cell_type": "markdown",
@@ -110,7 +130,12 @@
"id": "49838930",
"metadata": {},
"outputs": [],
- "source": ["from langchain_nvidia_ai_endpoints import ChatNVIDIA\n\n# connect to an embedding NIM running at localhost:8000, specifying a specific model\nllm = ChatNVIDIA(base_url=\"http://localhost:8000/v1\", model=\"meta/llama3-8b-instruct\")"]
+ "source": [
+ "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n",
+ "\n",
+ "# connect to an embedding NIM running at localhost:8000, specifying a specific model\n",
+ "llm = ChatNVIDIA(base_url=\"http://localhost:8000/v1\", model=\"meta/llama3-8b-instruct\")"
+ ]
},
{
"cell_type": "markdown",
@@ -128,7 +153,11 @@
"id": "01fa5095-be72-47b0-8247-e9fac799435d",
"metadata": {},
"outputs": [],
- "source": ["print(llm.batch([\"What's 2*3?\", \"What's 2*6?\"]))\n# Or via the async API\n# await llm.abatch([\"What's 2*3?\", \"What's 2*6?\"])"]
+ "source": [
+ "print(llm.batch([\"What's 2*3?\", \"What's 2*6?\"]))\n",
+ "# Or via the async API\n",
+ "# await llm.abatch([\"What's 2*3?\", \"What's 2*6?\"])"
+ ]
},
{
"cell_type": "code",
@@ -136,7 +165,11 @@
"id": "75189ac6-e13f-414f-9064-075c77d6e754",
"metadata": {},
"outputs": [],
- "source": ["for chunk in llm.stream(\"How far can a seagull fly in one day?\"):\n # Show the token separations\n print(chunk.content, end=\"|\")"]
+ "source": [
+ "for chunk in llm.stream(\"How far can a seagull fly in one day?\"):\n",
+ " # Show the token separations\n",
+ " print(chunk.content, end=\"|\")"
+ ]
},
{
"cell_type": "code",
@@ -144,7 +177,12 @@
"id": "8a9a4122-7a10-40c0-a979-82a769ce7f6a",
"metadata": {},
"outputs": [],
- "source": ["async for chunk in llm.astream(\n \"How long does it take for monarch butterflies to migrate?\"\n):\n print(chunk.content, end=\"|\")"]
+ "source": [
+ "async for chunk in llm.astream(\n",
+ " \"How long does it take for monarch butterflies to migrate?\"\n",
+ "):\n",
+ " print(chunk.content, end=\"|\")"
+ ]
},
{
"cell_type": "markdown",
@@ -166,7 +204,10 @@
"id": "5b8a312d-38e9-4528-843e-59451bdadbac",
"metadata": {},
"outputs": [],
- "source": ["ChatNVIDIA.get_available_models()\n# llm.get_available_models()"]
+ "source": [
+ "ChatNVIDIA.get_available_models()\n",
+ "# llm.get_available_models()"
+ ]
},
{
"cell_type": "markdown",
@@ -206,7 +247,19 @@
"id": "f5f7aee8-e90c-4d5a-ac97-0dd3d45c3f4c",
"metadata": {},
"outputs": [],
- "source": ["from langchain_core.output_parsers import StrOutputParser\nfrom langchain_core.prompts import ChatPromptTemplate\nfrom langchain_nvidia_ai_endpoints import ChatNVIDIA\n\nprompt = ChatPromptTemplate.from_messages(\n [(\"system\", \"You are a helpful AI assistant named Fred.\"), (\"user\", \"{input}\")]\n)\nchain = prompt | ChatNVIDIA(model=\"meta/llama3-8b-instruct\") | StrOutputParser()\n\nfor txt in chain.stream({\"input\": \"What's your name?\"}):\n print(txt, end=\"\")"]
+ "source": [
+ "from langchain_core.output_parsers import StrOutputParser\n",
+ "from langchain_core.prompts import ChatPromptTemplate\n",
+ "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n",
+ "\n",
+ "prompt = ChatPromptTemplate.from_messages(\n",
+ " [(\"system\", \"You are a helpful AI assistant named Fred.\"), (\"user\", \"{input}\")]\n",
+ ")\n",
+ "chain = prompt | ChatNVIDIA(model=\"meta/llama3-8b-instruct\") | StrOutputParser()\n",
+ "\n",
+ "for txt in chain.stream({\"input\": \"What's your name?\"}):\n",
+ " print(txt, end=\"\")"
+ ]
},
{
"cell_type": "markdown",
@@ -224,7 +277,21 @@
"id": "49aa569b-5f33-47b3-9edc-df58313eb038",
"metadata": {},
"outputs": [],
- "source": ["prompt = ChatPromptTemplate.from_messages(\n [\n (\n \"system\",\n \"You are an expert coding AI. Respond only in valid python; no narration whatsoever.\",\n ),\n (\"user\", \"{input}\"),\n ]\n)\nchain = prompt | ChatNVIDIA(model=\"meta/codellama-70b\") | StrOutputParser()\n\nfor txt in chain.stream({\"input\": \"How do I solve this fizz buzz problem?\"}):\n print(txt, end=\"\")"]
+ "source": [
+ "prompt = ChatPromptTemplate.from_messages(\n",
+ " [\n",
+ " (\n",
+ " \"system\",\n",
+ " \"You are an expert coding AI. Respond only in valid python; no narration whatsoever.\",\n",
+ " ),\n",
+ " (\"user\", \"{input}\"),\n",
+ " ]\n",
+ ")\n",
+ "chain = prompt | ChatNVIDIA(model=\"meta/codellama-70b\") | StrOutputParser()\n",
+ "\n",
+ "for txt in chain.stream({\"input\": \"How do I solve this fizz buzz problem?\"}):\n",
+ " print(txt, end=\"\")"
+ ]
},
{
"cell_type": "markdown",
@@ -244,7 +311,15 @@
"id": "26625437-1695-440f-b792-b85e6add9a90",
"metadata": {},
"outputs": [],
- "source": ["import IPython\nimport requests\n\nimage_url = \"https://www.nvidia.com/content/dam/en-zz/Solutions/research/ai-playground/nvidia-picasso-3c33-p@2x.jpg\" ## Large Image\nimage_content = requests.get(image_url).content\n\nIPython.display.Image(image_content)"]
+ "source": [
+ "import IPython\n",
+ "import requests\n",
+ "\n",
+ "image_url = \"https://www.nvidia.com/content/dam/en-zz/Solutions/research/ai-playground/nvidia-picasso-3c33-p@2x.jpg\" ## Large Image\n",
+ "image_content = requests.get(image_url).content\n",
+ "\n",
+ "IPython.display.Image(image_content)"
+ ]
},
{
"cell_type": "code",
@@ -252,7 +327,11 @@
"id": "dfbbe57c-27a5-4cbb-b967-19c4e7d29fd0",
"metadata": {},
"outputs": [],
- "source": ["from langchain_nvidia_ai_endpoints import ChatNVIDIA\n\nllm = ChatNVIDIA(model=\"nvidia/neva-22b\")"]
+ "source": [
+ "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n",
+ "\n",
+ "llm = ChatNVIDIA(model=\"nvidia/neva-22b\")"
+ ]
},
{
"cell_type": "markdown",
@@ -268,7 +347,20 @@
"id": "432ea2a2-4d39-43f8-a236-041294171f14",
"metadata": {},
"outputs": [],
- "source": ["from langchain_core.messages import HumanMessage\n\nllm.invoke(\n [\n HumanMessage(\n content=[\n {\"type\": \"text\", \"text\": \"Describe this image:\"},\n {\"type\": \"image_url\", \"image_url\": {\"url\": image_url}},\n ]\n )\n ]\n)"]
+ "source": [
+ "from langchain_core.messages import HumanMessage\n",
+ "\n",
+ "llm.invoke(\n",
+ " [\n",
+ " HumanMessage(\n",
+ " content=[\n",
+ " {\"type\": \"text\", \"text\": \"Describe this image:\"},\n",
+ " {\"type\": \"image_url\", \"image_url\": {\"url\": image_url}},\n",
+ " ]\n",
+ " )\n",
+ " ]\n",
+ ")"
+ ]
},
{
"cell_type": "markdown",
@@ -366,7 +458,15 @@
"id": "c58f1dd0",
"metadata": {},
"outputs": [],
- "source": ["import IPython\nimport requests\n\nimage_url = \"https://picsum.photos/seed/kitten/300/200\"\nimage_content = requests.get(image_url).content\n\nIPython.display.Image(image_content)"]
+ "source": [
+ "import IPython\n",
+ "import requests\n",
+ "\n",
+ "image_url = \"https://picsum.photos/seed/kitten/300/200\"\n",
+ "image_content = requests.get(image_url).content\n",
+ "\n",
+ "IPython.display.Image(image_content)"
+ ]
},
{
"cell_type": "code",
@@ -374,7 +474,28 @@
"id": "8c721629-42eb-4006-bf68-0296f7925ebc",
"metadata": {},
"outputs": [],
- "source": ["import base64\n\nfrom langchain_core.messages import HumanMessage\n\n## Works for simpler images. For larger images, see actual implementation\nb64_string = base64.b64encode(image_content).decode(\"utf-8\")\n\nllm.invoke(\n [\n HumanMessage(\n content=[\n {\"type\": \"text\", \"text\": \"Describe this image:\"},\n {\n \"type\": \"image_url\",\n \"image_url\": {\"url\": f\"data:image/png;base64,{b64_string}\"},\n },\n ]\n )\n ]\n)"]
+ "source": [
+ "import base64\n",
+ "\n",
+ "from langchain_core.messages import HumanMessage\n",
+ "\n",
+ "## Works for simpler images. For larger images, see actual implementation\n",
+ "b64_string = base64.b64encode(image_content).decode(\"utf-8\")\n",
+ "\n",
+ "llm.invoke(\n",
+ " [\n",
+ " HumanMessage(\n",
+ " content=[\n",
+ " {\"type\": \"text\", \"text\": \"Describe this image:\"},\n",
+ " {\n",
+ " \"type\": \"image_url\",\n",
+ " \"image_url\": {\"url\": f\"data:image/png;base64,{b64_string}\"},\n",
+ " },\n",
+ " ]\n",
+ " )\n",
+ " ]\n",
+ ")"
+ ]
},
{
"cell_type": "markdown",
@@ -392,7 +513,10 @@
"id": "00c06a9a-497b-4192-a842-b075e27401aa",
"metadata": {},
"outputs": [],
- "source": ["base64_with_mime_type = f\"data:image/png;base64,{b64_string}\"\nllm.invoke(f'What\\'s in this image?\\n
')"]
+ "source": [
+ "base64_with_mime_type = f\"data:image/png;base64,{b64_string}\"\n",
+ "llm.invoke(f'What\\'s in this image?\\n
')"
+ ]
},
{
"cell_type": "markdown",
@@ -420,7 +544,9 @@
"id": "082ccb21-91e1-4e71-a9ba-4bff1e89f105",
"metadata": {},
"outputs": [],
- "source": ["%pip install --upgrade --quiet langchain"]
+ "source": [
+ "%pip install --upgrade --quiet langchain"
+ ]
},
{
"cell_type": "code",
@@ -430,7 +556,41 @@
"id": "fd2c6bc1"
},
"outputs": [],
- "source": ["from langchain_core.chat_history import InMemoryChatMessageHistory\nfrom langchain_core.runnables.history import RunnableWithMessageHistory\n\n# store is a dictionary that maps session IDs to their corresponding chat histories.\nstore = {} # memory is maintained outside the chain\n\n\n# A function that returns the chat history for a given session ID.\ndef get_session_history(session_id: str) -> InMemoryChatMessageHistory:\n if session_id not in store:\n store[session_id] = InMemoryChatMessageHistory()\n return store[session_id]\n\n\nchat = ChatNVIDIA(\n model=\"mistralai/mixtral-8x22b-instruct-v0.1\",\n temperature=0.1,\n max_tokens=100,\n top_p=1.0,\n)\n\n# Define a RunnableConfig object, with a `configurable` key. session_id determines thread\nconfig = {\"configurable\": {\"session_id\": \"1\"}}\n\nconversation = RunnableWithMessageHistory(\n chat,\n get_session_history,\n)\n\nconversation.invoke(\n \"Hi I'm Srijan Dubey.\", # input or query\n config=config,\n)"]
+ "source": [
+ "from langchain_core.chat_history import InMemoryChatMessageHistory\n",
+ "from langchain_core.runnables.history import RunnableWithMessageHistory\n",
+ "\n",
+ "# store is a dictionary that maps session IDs to their corresponding chat histories.\n",
+ "store = {} # memory is maintained outside the chain\n",
+ "\n",
+ "\n",
+ "# A function that returns the chat history for a given session ID.\n",
+ "def get_session_history(session_id: str) -> InMemoryChatMessageHistory:\n",
+ " if session_id not in store:\n",
+ " store[session_id] = InMemoryChatMessageHistory()\n",
+ " return store[session_id]\n",
+ "\n",
+ "\n",
+ "chat = ChatNVIDIA(\n",
+ " model=\"mistralai/mixtral-8x22b-instruct-v0.1\",\n",
+ " temperature=0.1,\n",
+ " max_tokens=100,\n",
+ " top_p=1.0,\n",
+ ")\n",
+ "\n",
+ "# Define a RunnableConfig object, with a `configurable` key. session_id determines thread\n",
+ "config = {\"configurable\": {\"session_id\": \"1\"}}\n",
+ "\n",
+ "conversation = RunnableWithMessageHistory(\n",
+ " chat,\n",
+ " get_session_history,\n",
+ ")\n",
+ "\n",
+ "conversation.invoke(\n",
+ " \"Hi I'm Srijan Dubey.\", # input or query\n",
+ " config=config,\n",
+ ")"
+ ]
},
{
"cell_type": "code",
@@ -445,7 +605,12 @@
"outputId": "79acc89d-a820-4f2c-bac2-afe99da95580"
},
"outputs": [],
- "source": ["conversation.invoke(\n \"I'm doing well! Just having a conversation with an AI.\",\n config=config,\n)"]
+ "source": [
+ "conversation.invoke(\n",
+ " \"I'm doing well! Just having a conversation with an AI.\",\n",
+ " config=config,\n",
+ ")"
+ ]
},
{
"cell_type": "code",
@@ -460,7 +625,12 @@
"outputId": "a1714513-a8fd-4d14-f974-233e39d5c4f5"
},
"outputs": [],
- "source": ["conversation.invoke(\n \"Tell me about yourself.\",\n config=config,\n)"]
+ "source": [
+ "conversation.invoke(\n",
+ " \"Tell me about yourself.\",\n",
+ " config=config,\n",
+ ")"
+ ]
},
{
"cell_type": "markdown",
@@ -488,7 +658,10 @@
"id": "e36c8911",
"metadata": {},
"outputs": [],
- "source": ["tool_models = [model for model in ChatNVIDIA.get_available_models() if model.supports_tools]\ntool_models"]
+ "source": [
+ "tool_models = [model for model in ChatNVIDIA.get_available_models() if model.supports_tools]\n",
+ "tool_models"
+ ]
},
{
"cell_type": "markdown",
@@ -504,7 +677,21 @@
"id": "bd54f174",
"metadata": {},
"outputs": [],
- "source": ["from pydantic import Field\nfrom langchain_core.tools import tool\n\n@tool\ndef get_current_weather(\n location: str = Field(..., description=\"The location to get the weather for.\")\n):\n \"\"\"Get the current weather for a location.\"\"\"\n ...\n\nllm = ChatNVIDIA(model=tool_models[0].id).bind_tools(tools=[get_current_weather])\nresponse = llm.invoke(\"What is the weather in Boston?\")\nresponse.tool_calls"]
+ "source": [
+ "from pydantic import Field\n",
+ "from langchain_core.tools import tool\n",
+ "\n",
+ "@tool\n",
+ "def get_current_weather(\n",
+ " location: str = Field(..., description=\"The location to get the weather for.\")\n",
+ "):\n",
+ " \"\"\"Get the current weather for a location.\"\"\"\n",
+ " ...\n",
+ "\n",
+ "llm = ChatNVIDIA(model=tool_models[0].id).bind_tools(tools=[get_current_weather])\n",
+ "response = llm.invoke(\"What is the weather in Boston?\")\n",
+ "response.tool_calls"
+ ]
},
{
"cell_type": "markdown",
@@ -542,7 +729,11 @@
"id": "0515f558",
"metadata": {},
"outputs": [],
- "source": ["from langchain_nvidia_ai_endpoints import ChatNVIDIA\nstructured_models = [model for model in ChatNVIDIA.get_available_models() if model.supports_structured_output]\nstructured_models"]
+ "source": [
+ "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n",
+ "structured_models = [model for model in ChatNVIDIA.get_available_models() if model.supports_structured_output]\n",
+ "structured_models"
+ ]
},
{
"cell_type": "markdown",
@@ -558,7 +749,17 @@
"id": "482c37e8",
"metadata": {},
"outputs": [],
- "source": ["from pydantic import BaseModel, Field\n\nclass Person(BaseModel):\n first_name: str = Field(..., description=\"The person's first name.\")\n last_name: str = Field(..., description=\"The person's last name.\")\n\nllm = ChatNVIDIA(model=structured_models[0].id).with_structured_output(Person)\nresponse = llm.invoke(\"Who is Michael Jeffrey Jordon?\")\nresponse"]
+ "source": [
+ "from pydantic import BaseModel, Field\n",
+ "\n",
+ "class Person(BaseModel):\n",
+ " first_name: str = Field(..., description=\"The person's first name.\")\n",
+ " last_name: str = Field(..., description=\"The person's last name.\")\n",
+ "\n",
+ "llm = ChatNVIDIA(model=structured_models[0].id).with_structured_output(Person)\n",
+ "response = llm.invoke(\"Who is Michael Jeffrey Jordon?\")\n",
+ "response"
+ ]
},
{
"cell_type": "markdown",
@@ -574,7 +775,24 @@
"id": "7f802912",
"metadata": {},
"outputs": [],
- "source": ["from enum import Enum\n\nclass Choices(Enum):\n A = \"A\"\n B = \"B\"\n C = \"C\"\n\nllm = ChatNVIDIA(model=structured_models[2].id).with_structured_output(Choices)\nresponse = llm.invoke(\"\"\"\n What does 1+1 equal?\n A. -100\n B. 2\n C. doorstop\n \"\"\"\n)\nresponse"]
+ "source": [
+ "from enum import Enum\n",
+ "\n",
+ "class Choices(Enum):\n",
+ " A = \"A\"\n",
+ " B = \"B\"\n",
+ " C = \"C\"\n",
+ "\n",
+ "llm = ChatNVIDIA(model=structured_models[2].id).with_structured_output(Choices)\n",
+ "response = llm.invoke(\"\"\"\n",
+ " What does 1+1 equal?\n",
+ " A. -100\n",
+ " B. 2\n",
+ " C. doorstop\n",
+ " \"\"\"\n",
+ ")\n",
+ "response"
+ ]
},
{
"cell_type": "code",
@@ -582,7 +800,111 @@
"id": "02b7ef29",
"metadata": {},
"outputs": [],
- "source": ["model = structured_models[3].id\nllm = ChatNVIDIA(model=model).with_structured_output(Choices)\nprint(model)\nresponse = llm.invoke(\"\"\"\n What does 1+1 equal?\n A. -100\n B. 2\n C. doorstop\n \"\"\"\n)\nresponse"]
+ "source": [
+ "model = structured_models[3].id\n",
+ "llm = ChatNVIDIA(model=model).with_structured_output(Choices)\n",
+ "print(model)\n",
+ "response = llm.invoke(\"\"\"\n",
+ " What does 1+1 equal?\n",
+ " A. -100\n",
+ " B. 2\n",
+ " C. doorstop\n",
+ " \"\"\"\n",
+ ")\n",
+ "response"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b5476b50",
+ "metadata": {},
+ "source": [
+ "### JSON Schema"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6fb3e5cb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "json_schema = {\n",
+ " \"title\": \"joke\",\n",
+ " \"description\": \"Joke to tell user.\",\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"setup\": {\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"The setup of the joke\",\n",
+ " },\n",
+ " \"punchline\": {\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"The punchline to the joke\",\n",
+ " },\n",
+ " \"rating\": {\n",
+ " \"type\": \"integer\",\n",
+ " \"description\": \"How funny the joke is, from 1 to 10\",\n",
+ " \"default\": None,\n",
+ " },\n",
+ " },\n",
+ " \"required\": [\"setup\", \"punchline\"],\n",
+ "}\n",
+ "llm = ChatNVIDIA(model=structured_models[0].id)\n",
+ "structured_llm = llm.with_structured_output(json_schema)\n",
+ "\n",
+ "structured_llm.invoke(\"Tell me a joke about cats\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c8893046",
+ "metadata": {},
+ "source": [
+ "## [JSON mode](https://platform.openai.com/docs/guides/text-generation/json-mode)\n",
+ "\n",
+ "Constrain the model to only generate valid JSON. Note that you must include a system message with instructions to use JSON for this mode to work.\n",
+ "\n",
+ "Only works with certain models."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a1c955cc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain_core.messages import HumanMessage, SystemMessage\n",
+ "\n",
+ "chat = ChatNVIDIA(model=tool_models[0].id).bind(\n",
+ " response_format={\"type\": \"json_object\"}\n",
+ ")\n",
+ "\n",
+ "output = chat.invoke(\n",
+ " [\n",
+ " SystemMessage(\n",
+ " content=\"Extract the 'name' and 'origin' of any companies mentioned in the following statement. Return a JSON list.\"\n",
+ " ),\n",
+ " HumanMessage(\n",
+ " content=\"Google was founded in the USA, while Deepmind was founded in the UK\"\n",
+ " ),\n",
+ " ]\n",
+ ")\n",
+ "print(output.content)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2a884df2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "json.loads(output.content)"
+ ]
}
],
"metadata": {
@@ -590,7 +912,7 @@
"provenance": []
},
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -604,7 +926,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.13"
+ "version": "3.10.12"
}
},
"nbformat": 4,
diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
index 63ee21c5..6741a1a4 100644
--- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
+++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -48,8 +48,11 @@
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
-from langchain_core.utils.function_calling import convert_to_openai_tool
-from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.utils.function_calling import (
+ convert_to_openai_function,
+ convert_to_openai_tool,
+)
+from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from pydantic import BaseModel, Field, PrivateAttr
from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
@@ -237,6 +240,28 @@ def _process_for_vlm(
return inputs, extra_headers
+def _convert_to_openai_response_format(
+ schema: Union[Dict[str, Any], Type],
+) -> Union[Dict, TypeBaseModel]:
+ if isinstance(schema, type) and is_basemodel_subclass(schema):
+ return schema
+
+ if (
+ isinstance(schema, dict)
+ and "json_schema" in schema
+ and schema.get("type") == "json_schema"
+ ):
+ response_format = schema
+ elif isinstance(schema, dict) and "name" in schema and "schema" in schema:
+ response_format = {"type": "json_schema", "json_schema": schema}
+ else:
+ function = convert_to_openai_function(schema)
+ function["schema"] = function.pop("parameters")
+ response_format = {"type": "json_schema", "json_schema": function}
+
+ return response_format
+
+
_DEFAULT_MODEL_NAME: str = "meta/llama3-8b-instruct"
@@ -649,6 +674,7 @@ def with_structured_output( # type: ignore
self,
schema: Union[Dict, Type],
*,
+ method: Literal["json_mode", "json_schema"] = "json_schema",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
@@ -658,6 +684,13 @@ def with_structured_output( # type: ignore
Args:
schema (Union[Dict, Type]): The schema to bind to the model.
include_raw (bool): Always False. Passing True raises an error.
+ method: The method for steering model generation, one of:
+ - "json_schema":
+ Uses Structured Output API for supported models.
+ - "json_mode":
+ Uses JSON mode. Note that if using JSON mode then you
+ must include instructions for formatting the output into the
+ desired schema into the model call:
**kwargs: Additional keyword arguments.
Notes:
@@ -773,13 +806,6 @@ class Choices(enum.Enum):
For more, see https://python.langchain.com/docs/how_to/structured_output/
""" # noqa: E501
- if "method" in kwargs:
- warnings.warn(
- "The 'method' parameter is unnecessary and is ignored. "
- "The appropriate method will be chosen automatically depending "
- "on the type of schema provided."
- )
-
if kwargs.get("strict", True) is not True:
warnings.warn(
"Structured output always follows strict validation. "
@@ -796,6 +822,8 @@ class Choices(enum.Enum):
"being None when the LLM produces an incomplete response."
)
+ is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
+
# check if the model supports structured output, warn if it does not
known_good = False
# todo: we need to store model: Model in this class
@@ -814,11 +842,32 @@ class Choices(enum.Enum):
f"Model '{self.model}' is not known to support structured output. "
"Your output may fail at inference time."
)
+ output_parser: BaseOutputParser # create a common type
+
+ if method == "json_mode":
+ llm = self.bind(response_format={"type": "json_object"})
+ output_parser = (
+ PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
+ if is_pydantic_schema
+ else JsonOutputParser()
+ )
+ elif method == "json_schema":
+ response_format = _convert_to_openai_response_format(schema)
+ llm = self.bind(response_format=response_format)
+ output_parser = (
+ PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
+ if is_pydantic_schema
+ else JsonOutputParser()
+ )
+ else:
+ raise ValueError(
+ f"Unrecognized method argument. Expected one of 'json_scheme' or "
+ f"'json_mode'. Received: '{method}'"
+ )
if isinstance(schema, dict):
- output_parser: BaseOutputParser = JsonOutputParser()
- nvext_param: Dict[str, Any] = {"guided_json": schema}
-
+ output_parser = JsonOutputParser()
+ llm = self.bind(nvext={"guided_json": schema})
elif issubclass(schema, enum.Enum):
# langchain's EnumOutputParser is not in langchain_core
# and doesn't support streaming. this is a simple implementation
@@ -845,7 +894,7 @@ def parse(self, response: str) -> Any:
"Use StrEnum or ensure all member values are strings."
)
output_parser = EnumOutputParser(enum=schema)
- nvext_param = {"guided_choice": choices}
+ llm = self.bind(nvext={"guided_choice": choices})
elif is_basemodel_subclass(schema):
# PydanticOutputParser does not support streaming. what we do
@@ -867,7 +916,7 @@ def parse_result(
json_schema = schema.model_json_schema()
else:
json_schema = schema.schema()
- nvext_param = {"guided_json": json_schema}
+ llm = self.bind(nvext={"guided_json": json_schema})
else:
raise ValueError(
@@ -875,4 +924,4 @@ def parse_result(
"representing a JSON schema, or an Enum."
)
- return super().bind(nvext=nvext_param) | output_parser
+ return llm | output_parser
diff --git a/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py b/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py
index aaee110a..1f69a7cd 100644
--- a/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py
+++ b/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py
@@ -11,6 +11,7 @@
BaseMessage,
BaseMessageChunk,
)
+from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from pydantic import Field
@@ -787,3 +788,35 @@ def magic(
assert response_in < baseline_in * tolerance
assert response_out < baseline_out * tolerance
assert response_total < baseline_total * tolerance
+
+
+def test_json_mode(tool_model: str) -> None:
+ llm = ChatNVIDIA(model=tool_model).bind(response_format={"type": "json_object"})
+ response = llm.invoke(
+ "Return this as json: {'a': 1}",
+ )
+ assert isinstance(response.content, str)
+ assert json.loads(response.content) == {"a": 1}
+
+ # Test streaming
+ full: Optional[Union[BaseMessage, ChatPromptTemplate]] = None
+ for chunk in llm.stream("Return this as json: {'a': 1}"):
+ full = chunk if full is None else full + chunk
+ assert isinstance(full, AIMessageChunk)
+ assert isinstance(full.content, str)
+ assert json.loads(full.content) == {"a": 1}
+
+
+async def test_json_mode_async(tool_model: str) -> None:
+ llm = ChatNVIDIA(model=tool_model).bind(response_format={"type": "json_object"})
+ response = await llm.ainvoke("Return this as json: {'a': 1}")
+ assert isinstance(response.content, str)
+ assert json.loads(response.content) == {"a": 1}
+
+ # Test streaming
+ full: Optional[Union[BaseMessage, ChatPromptTemplate]] = None
+ async for chunk in llm.astream("Return this as json: {'a': 1}"):
+ full = chunk if full is None else full + chunk
+ assert isinstance(full, AIMessageChunk)
+ assert isinstance(full.content, str)
+ assert json.loads(full.content) == {"a": 1}
diff --git a/libs/ai-endpoints/tests/integration_tests/test_structured_output.py b/libs/ai-endpoints/tests/integration_tests/test_structured_output.py
index 1f059e75..66bd8dbb 100644
--- a/libs/ai-endpoints/tests/integration_tests/test_structured_output.py
+++ b/libs/ai-endpoints/tests/integration_tests/test_structured_output.py
@@ -1,9 +1,10 @@
import enum
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union
import pytest
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
+from pydantic import BaseModel as BaseModelProper
from langchain_nvidia_ai_endpoints import ChatNVIDIA
@@ -20,6 +21,26 @@ def do_stream(llm: ChatNVIDIA, message: str) -> Any:
return result[-1] if result else None
+class Joke(BaseModelProper):
+ """Joke to tell user."""
+
+ setup: str = Field(description="question to set up a joke")
+ punchline: str = Field(description="answer to resolve the joke")
+
+
+class SelfEvaluation(BaseModelProper):
+ score: int
+ text: str
+
+
+class JokeWithEvaluation(BaseModelProper):
+ """Joke to tell user."""
+
+ setup: str
+ punchline: str
+ self_evaluation: SelfEvaluation
+
+
@pytest.mark.xfail(reason="Accuracy is not guaranteed")
def test_accuracy(structured_model: str, mode: dict) -> None:
class Person(BaseModel):
@@ -66,14 +87,6 @@ class Person(BaseModel):
assert person.birthplace == "Tainan, Taiwan"
-class Joke(BaseModel):
- """Joke to tell user."""
-
- setup: str = Field(description="The setup of the joke")
- punchline: str = Field(description="The punchline to the joke")
- rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
-
-
@pytest.mark.parametrize("func", [do_invoke, do_stream], ids=["invoke", "stream"])
def test_pydantic(structured_model: str, mode: dict, func: Callable) -> None:
llm = ChatNVIDIA(model=structured_model, temperature=0, **mode)
@@ -184,3 +197,142 @@ def test_pydantic_incomplete(structured_model: str, mode: dict, func: Callable)
structured_llm = llm.with_structured_output(Joke)
result = func(structured_llm, "Tell me a joke about cats")
assert result is None
+
+
+def joke(result: Any) -> None:
+ assert isinstance(result, dict)
+ assert all(key in set(result.keys()) for key in {"setup", "punchline"})
+
+
+def nested_json(result: Any) -> None:
+ assert isinstance(result, dict) # for mypy
+ assert set(result.keys()) == {"setup", "punchline", "self_evaluation"}
+ assert set(result["self_evaluation"].keys()) == {"score", "text"}
+
+
+@pytest.mark.parametrize(
+ ("method", "strict"),
+ [("json_schema", None), ("json_mode", None)],
+)
+def test_structured_output_json_strict(
+ structured_model: str,
+ mode: dict,
+ method: Literal["json_mode", "json_schema"],
+ strict: Optional[bool],
+) -> None:
+ """Test to verify structured output with strict=True."""
+
+ llm = ChatNVIDIA(model=structured_model, temperature=0, **mode)
+
+ # Test structured output with a Pydantic class
+ chat = llm.with_structured_output(Joke, method=method, strict=strict)
+ result = chat.invoke("Tell me a joke about cats.")
+
+ assert isinstance(result, Joke)
+
+ for chunk in chat.stream("Tell me a joke about cats."):
+ assert isinstance(chunk, Joke)
+
+ # Test structured output with JSON schema
+ chat = llm.with_structured_output(
+ Joke.model_json_schema(), method=method, strict=strict
+ )
+ result = chat.invoke("Tell me a joke about cats.")
+ joke(result)
+
+ for chunk in chat.stream("Tell me a joke about cats."):
+ assert isinstance(chunk, dict)
+ joke(chunk)
+
+
+@pytest.mark.parametrize(
+ ("method", "strict"), [("json_schema", None), ("json_mode", None)]
+)
+def test_nested_structured_output_json_strict(
+ structured_model: str,
+ mode: dict,
+ method: Literal["json_schema", "json_mode"],
+ strict: Optional[bool],
+) -> None:
+ """Test to verify structured output with strict=True for nested object."""
+
+ llm = ChatNVIDIA(model=structured_model, temperature=0, **mode)
+
+ # Schema
+ chat = llm.with_structured_output(
+ JokeWithEvaluation.model_json_schema(), method=method, strict=strict
+ )
+ result = chat.invoke("Tell me a joke about cats.")
+ nested_json(result)
+
+ for chunk in chat.stream("Tell me a joke about cats."):
+ assert isinstance(chunk, dict)
+ nested_json(chunk)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("method", "strict"),
+ [("json_schema", None), ("json_mode", None)],
+)
+async def test_structured_output_json_strict_async(
+ structured_model: str,
+ method: Literal["json_schema", "json_mode"],
+ strict: Optional[bool],
+) -> None:
+ """Test to verify structured output with strict=True (async)."""
+
+ llm = ChatNVIDIA(model=structured_model, temperature=0)
+
+ # Pydantic class
+ chat = llm.with_structured_output(Joke, method=method, strict=strict)
+ result = await chat.ainvoke("Tell me a joke about cats.")
+ assert isinstance(result, Joke)
+
+ async for chunk in chat.astream("Tell me a joke about cats."):
+ assert isinstance(chunk, Joke)
+
+ # Schema
+ chat = llm.with_structured_output(
+ Joke.model_json_schema(), method=method, strict=strict
+ )
+ result = await chat.ainvoke("Tell me a joke about cats.")
+ joke(result)
+
+ async for chunk in chat.astream("Tell me a joke about cats."):
+ assert isinstance(chunk, dict)
+ joke(chunk)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("method", "strict"), [("json_schema", None), ("json_mode", None)]
+)
+async def test_nested_structured_output_json_strict_async(
+ structured_model: str, method: Literal["json_schema"], strict: Optional[bool]
+) -> None:
+ """Test to verify structured output with strict=True for nested object (async)."""
+
+ llm = ChatNVIDIA(model=structured_model, temperature=0)
+
+ # Schema
+ chat = llm.with_structured_output(
+ JokeWithEvaluation.model_json_schema(), method=method, strict=strict
+ )
+ result = await chat.ainvoke("Tell me a joke about cats.")
+ nested_json(result)
+
+ async for chunk in chat.astream("Tell me a joke about cats."):
+ assert isinstance(chunk, dict)
+ nested_json(chunk)
+
+
+def test_json_mode_with_dict(structured_model: str) -> None:
+ """Test json_mode with a dictionary schema."""
+ schema = {
+ "type": "object",
+ "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+ }
+
+ llm = ChatNVIDIA(model=structured_model)
+ llm.with_structured_output(schema, method="json_mode")
diff --git a/libs/ai-endpoints/tests/unit_tests/test_bind_tools.py b/libs/ai-endpoints/tests/unit_tests/test_bind_tools.py
index a2a26436..920e32e0 100644
--- a/libs/ai-endpoints/tests/unit_tests/test_bind_tools.py
+++ b/libs/ai-endpoints/tests/unit_tests/test_bind_tools.py
@@ -342,3 +342,39 @@ def test_strict_no_warns(strict: Optional[bool]) -> None:
tools=[xxyyzz_tool_annotated],
**({"strict": strict} if strict is not None else {}),
)
+
+
+def test_json_mode(
+ requests_mock: requests_mock.Mocker,
+ mock_v1_models: None,
+) -> None:
+ requests_mock.post(
+ "https://integrate.api.nvidia.com/v1/chat/completions",
+ json={
+ "id": "chatcmpl-ID",
+ "object": "chat.completion",
+ "created": 1234567890,
+ "model": "BOGUS",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": '{"a": 1}',
+ },
+ "logprobs": None,
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 22,
+ "completion_tokens": 20,
+ "total_tokens": 42,
+ },
+ "system_fingerprint": None,
+ },
+ )
+
+ llm = ChatNVIDIA(api_key="BOGUS").bind(response_format={"type": "json_object"})
+ response = llm.invoke("Return this as json: {'a': 1}")
+ assert isinstance(response, AIMessage)
+ assert json.loads(str(response.content)) == {"a": 1}
diff --git a/libs/ai-endpoints/tests/unit_tests/test_structured_output.py b/libs/ai-endpoints/tests/unit_tests/test_structured_output.py
index c8dfd00b..744a6f82 100644
--- a/libs/ai-endpoints/tests/unit_tests/test_structured_output.py
+++ b/libs/ai-endpoints/tests/unit_tests/test_structured_output.py
@@ -19,19 +19,6 @@ class Joke(pydanticV2BaseModel):
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
-def test_method() -> None:
- with pytest.warns(UserWarning) as record:
- with warnings.catch_warnings():
- warnings.filterwarnings(
- "ignore",
- category=UserWarning,
- message=".*not known to support structured output.*",
- )
- ChatNVIDIA(api_key="BOGUS").with_structured_output(Joke, method="json_mode")
- assert len(record) == 1
- assert "unnecessary" in str(record[0].message)
-
-
def test_include_raw() -> None:
with pytest.raises(NotImplementedError):
ChatNVIDIA(api_key="BOGUS").with_structured_output(Joke, include_raw=True)
@@ -226,7 +213,5 @@ def test_strict_no_warns(strict: Optional[bool]) -> None:
"ignore", category=UserWarning, message=".*not known to support.*"
)
- ChatNVIDIA(api_key="BOGUS").with_structured_output(
- Joke,
- **({"strict": strict} if strict is not None else {}),
- )
+ if strict:
+ ChatNVIDIA(api_key="BOGUS").with_structured_output(Joke, strict=strict)