From 80b6f1980f2aa9004d07aa0a4461af3ad47961ef Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Tue, 26 Nov 2024 02:29:53 +0000 Subject: [PATCH 01/40] Add Spanner Graph QA Chain --- README.rst | 29 + docs/graph_qa_chain.ipynb | 689 +++++++++++++++++++++ requirements.txt | 5 +- src/langchain_google_spanner/__init__.py | 2 + src/langchain_google_spanner/graph_qa.py | 396 ++++++++++++ src/langchain_google_spanner/prompts.py | 230 +++++++ tests/integration/test_spanner_graph_qa.py | 195 ++++++ 7 files changed, 1545 insertions(+), 1 deletion(-) create mode 100644 docs/graph_qa_chain.ipynb create mode 100644 src/langchain_google_spanner/graph_qa.py create mode 100644 src/langchain_google_spanner/prompts.py create mode 100644 tests/integration/test_spanner_graph_qa.py diff --git a/README.rst b/README.rst index 238ebce..5111312 100644 --- a/README.rst +++ b/README.rst @@ -151,6 +151,35 @@ See the full `Spanner Graph Store`_ tutorial. .. _`Spanner Graph Store`: https://github.com/googleapis/langchain-google-spanner-python/blob/main/docs/graph_store.ipynb +Spanner Graph QA Usage +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Use ``SpannerGraphQAChain`` for question answering over a graph stored in Spanner Graph. + +.. code:: python + + from langchain_google_spanner import SpannerGraphQAChain + from langchain_google_spanner import SpannerGraphStore + from langchain_google_vertexai import ChatVertexAI + + + graph = SpannerGraphStore( + instance_id="my-instance", + database_id="my-database", + graph_name="my_graph", + ) + llm = ChatVertexAI() + chain = SpannerGraphQAChain.from_llm( + llm, + graph=graph, + allow_dangerous_requests=True + ) + chain.invoke("query=Where does Sarah's sibling live?") + +See the full `Spanner Graph QA Chain`_ tutorial. + +.. 
_`Spanner Graph QA Chain`: https://github.com/googleapis/langchain-google-spanner-python/blob/main/docs/graph_qa_chain.ipynb + Contributions ~~~~~~~~~~~~~ diff --git a/docs/graph_qa_chain.ipynb b/docs/graph_qa_chain.ipynb new file mode 100644 index 0000000..e6a5716 --- /dev/null +++ b/docs/graph_qa_chain.ipynb @@ -0,0 +1,689 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + }, + "colab": { + "provenance": [], + "toc_visible": true + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Google Spanner\n", + "\n", + "> [Spanner](https://cloud.google.com/spanner) is a highly scalable database that combines unlimited scalability with relational semantics, such as secondary indexes, strong consistency, schemas, and SQL providing 99.999% availability in one easy solution.\n", + "\n", + "This notebook goes over how to use `Spanner` for GraphRAG with `SpannerGraphStore` and `SpannerGraphQAChain` class.\n", + "\n", + "Learn more about the package on [GitHub](https://github.com/googleapis/langchain-google-spanner-python/).\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/langchain-google-spanner-python/blob/main/docs/graph_store.ipynb)" + ], + "metadata": { + "id": "7VBkjcqNNxEd" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Before You Begin\n", + "\n", + "To run this notebook, you will need to do the following:\n", + "\n", + " * [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n", + " * [Enable the Cloud Spanner 
API](https://console.cloud.google.com/flows/enableapi?apiid=spanner.googleapis.com)\n", + " * [Create a Spanner instance](https://cloud.google.com/spanner/docs/create-manage-instances)\n", + " * [Create a Spanner database](https://cloud.google.com/spanner/docs/create-manage-databases)" + ], + "metadata": { + "id": "HEAGYTPgNydh" + } + }, + { + "cell_type": "markdown", + "source": [ + "### 🦜🔗 Library Installation\n", + "The integration lives in its own `langchain-google-spanner` package, so we need to install it." + ], + "metadata": { + "id": "cboPIg-yOcxS" + } + }, + { + "cell_type": "code", + "source": [ + "%pip install --upgrade --quiet langchain-google-spanner langchain-google-vertexai langchain-experimental json-repair pyvis" + ], + "metadata": { + "id": "AOWh6QKYVdDp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top." + ], + "metadata": { + "id": "M7MqpDhkOiP-" + } + }, + { + "cell_type": "code", + "source": [ + "# # Automatically restart kernel after installs so that your environment can access the new packages\n", + "import IPython\n", + "\n", + "app = IPython.Application.instance()\n", + "app.kernel.do_shutdown(True)" + ], + "metadata": { + "id": "xzgVZv0POj17" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### 🔐 Authentication\n", + "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n", + "\n", + "* If you are using Colab to run this notebook, use the cell below and continue.\n", + "* If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)." 
+ ],
+ "metadata": {
+ "id": "zfIhwIryOls1"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from google.colab import auth\n",
+ "\n",
+ "auth.authenticate_user()"
+ ],
+ "metadata": {
+ "id": "EWOkHI7XOna2"
+ },
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### ☁ Set Your Google Cloud Project\n",
+ "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n",
+ "\n",
+ "If you don't know your project ID, try the following:\n",
+ "\n",
+ "* Run `gcloud config list`.\n",
+ "* Run `gcloud projects list`.\n",
+ "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)."
+ ],
+ "metadata": {
+ "id": "6xHXneICOpsB"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n",
+ "\n",
+ "PROJECT_ID = \"google.com:cloud-spanner-demo\" # @param {type:\"string\"}\n",
+ "\n",
+ "# Set the project id\n",
+ "!gcloud config set project {PROJECT_ID}\n",
+ "%env GOOGLE_CLOUD_PROJECT={PROJECT_ID}"
+ ],
+ "metadata": {
+ "id": "hF0481BGOsS8"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### 💡 API Enablement\n",
+ "The `langchain-google-spanner` package requires that you [enable the Spanner API](https://console.cloud.google.com/flows/enableapi?apiid=spanner.googleapis.com) in your Google Cloud Project."
+ ],
+ "metadata": {
+ "id": "4TiC0RbhOwUu"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# enable Spanner API\n",
+ "!gcloud services enable spanner.googleapis.com"
+ ],
+ "metadata": {
+ "id": "9f3fJd5eOyRr"
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "You must also [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com)."
+ ], + "metadata": { + "id": "bT_S-jaEOW4P" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Basic Usage" + ], + "metadata": { + "id": "k5pxMMiMOzt7" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Set Spanner database values\n", + "Find your database values, in the [Spanner Instances page](https://console.cloud.google.com/spanner?_ga=2.223735448.2062268965.1707700487-2088871159.1707257687)." + ], + "metadata": { + "id": "mtDbLU5sO2iA" + } + }, + { + "cell_type": "code", + "source": [ + "# @title Set Your Values Here { display-mode: \"form\" }\n", + "INSTANCE = \"\" # @param {type: \"string\"}\n", + "DATABASE = \"\" # @param {type: \"string\"}\n", + "GRAPH_NAME = \"\" # @param {type: \"string\"}" + ], + "metadata": { + "id": "C-I8VTIcO442" + }, + "execution_count": 15, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### SpannerGraphStore\n", + "\n", + "To initialize the `SpannerGraphStore` class you need to provide 3 required arguments and other arguments are optional and only need to pass if it's different from default ones\n", + "\n", + "1. a Spanner instance id;\n", + "2. a Spanner database id belongs to the above instance id;\n", + "3. a Spanner graph name used to create a graph in the above database." 
+ ],
+ "metadata": {
+ "id": "kpAv-tpcO_iL"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from langchain_google_spanner import SpannerGraphStore\n",
+ "\n",
+ "graph_store = SpannerGraphStore(\n",
+ " instance_id=INSTANCE,\n",
+ " database_id=DATABASE,\n",
+ " graph_name=GRAPH_NAME,\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "u589YapWQFb8"
+ },
+ "execution_count": 16,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#### Add Graph Documents to Spanner Graph"
+ ],
+ "metadata": {
+ "id": "G7-Pe2ADQlNJ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ " # @title Extract Nodes and Edges from text snippets\n",
+ "from langchain_core.documents import Document\n",
+ "from langchain_experimental.graph_transformers import LLMGraphTransformer\n",
+ "from langchain_google_vertexai import ChatVertexAI\n",
+ "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
+ "\n",
+ "text_snippets = [\n",
+ "# Text snippet for students graduating from Veritas University, Computer Science Dept 2017\n",
+ "\n",
+ "\"\"\"\n",
+ "This was the graduation ceremony of 2017. A wave of jubilant graduates poured out of the\n",
+ "grand halls of Veritas University, their laughter echoing across the quad. Among them were\n",
+ "a cohort of exceptional students from the Computer Science department, a group that had\n",
+ "become known for their collaborative spirit and innovative ideas.\n",
+ "Leading the pack was Emily Davis, a coding whiz with a passion for cybersecurity, already\n",
+ "fielding offers from top tech firms. Beside her walked James Rodriguez, a quiet but\n",
+ "brilliant mind fascinated by artificial intelligence, dreaming of building machines that\n",
+ "could understand human emotions. Trailing slightly behind, deep in conversation, were\n",
+ "Sarah Chen and Michael Patel, both aspiring game developers, eager to bring their creative\n",
+ "visions to life. 
And then there was Aisha Khan, a social justice advocate who planned to\n",
+ "use her coding skills to address inequality through technology.\n",
+ "As they celebrated their achievements, these Veritas University Computer Science graduates\n",
+ "were ready to embark on diverse paths, each carrying the potential to shape the future of\n",
+ "technology in their own unique way.\n",
+ "\"\"\",\n",
+ "\n",
+ "# Text snippet for students graduating from Oakhaven University, Computer Science Dept 2016\n",
+ "\"\"\"\n",
+ "The year was 2016, and a palpable buzz filled the air as the graduating class of Oakhaven\n",
+ "university from Computer science and Engineering department emerged from the Beckman\n",
+ "Auditorium. Among them was a group of exceptional students, renowned for their\n",
+ "intellectual curiosity and groundbreaking research.\n",
+ "At the forefront was Alice Johnson, a gifted programmer with a fascination for quantum\n",
+ "computing, already collaborating with leading researchers in the field. Beside her\n",
+ "strode David Kim, a brilliant theorist captivated by the intricacies of cryptography,\n",
+ "eager to contribute to the development of secure communication systems. Engaged in an\n",
+ "animated discussion were Maria Rodriguez and Robert Lee, both passionate about robotics\n",
+ "and determined to push the boundaries of artificial intelligence. 
And then there was\n", + "Chloe Brown, a visionary with a deep interest in bioinformatics, driven to unlock the\n", + "secrets of the human genome through computational analysis.\n", + "As they celebrated their accomplishments, these graduates, armed with their exceptional\n", + "skills and unwavering determination, were poised to make significant contributions to the world of computing and beyond.\n", + "\"\"\",\n", + "\n", + "# Text snippet mentions the company Emily Davis founded.\n", + "# The snippet doesn't mention that she is an alumni of Veritas University\n", + "\"\"\"\n", + "Emily Davis, a name synonymous with cybersecurity innovation, turned that passion into a\n", + "thriving business. In the year 2022, Davis founded Ironclad Security, a company that's\n", + "rapidly changing the landscape of cybersecurity solutions.\n", + "\"\"\",\n", + "\n", + "# Text snippet mentions the company Alice Johnson founded.\n", + "# The snippet doesn't mention that she is an alumni of Oakhaven University.\n", + "\"\"\"\n", + "Alice Johnson had a vision that extended far beyond the classroom. 
Driven by an insatiable\n", + "curiosity about the potential of quantum mechanics, she founded Entangled Solutions, a\n", + "company poised to revolutionize industries through the power of quantum technology.\n", + "Entangled Solutions distinguishes itself by focusing on practical applications of quantum\n", + "computing.\n", + "\"\"\"\n", + "]\n", + "\n", + "# Create splits for documents\n", + "documents = [Document(page_content=t) for t in text_snippets]\n", + "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n", + "splits = text_splitter.split_documents(documents)\n", + "llm = ChatVertexAI(model=\"gemini-1.5-flash\", temperature=0)\n", + "llm_transformer = LLMGraphTransformer(\n", + " llm=llm,\n", + " allowed_nodes = [\"College\", \"Deparatment\", \"Person\", \"Year\", \"Company\"],\n", + " allowed_relationships = [\"AlumniOf\", \"StudiedInDepartment\", \"PartOf\", \"GraduatedInYear\", \"Founded\"],\n", + " node_properties=[ \"description\", ],\n", + ")\n", + "graph_documents = llm_transformer.convert_to_graph_documents(splits)\n" + ], + "metadata": { + "id": "fP7XNu3aPl5c" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title Print extracted nodes and edges\n", + "for doc in graph_documents:\n", + " print(doc.source.page_content[:100])\n", + " print(doc.nodes)\n", + " print(doc.relationships)\n", + " print()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OylyNyv-ZsT2", + "outputId": "e4253d98-ad63-4ea8-a5f1-0e3dac8f6632" + }, + "execution_count": 13, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "This was the graduation ceremony of 2017. 
A wave of jubilant graduates poured out of the\n", + "grand halls\n", + "[Node(id='Veritas University', type='College', properties={'description': 'grand halls'}), Node(id='Computer Science', type='Deparatment', properties={}), Node(id='2017', type='Year', properties={}), Node(id='Emily Davis', type='Person', properties={'description': 'coding whiz with a passion for cybersecurity'}), Node(id='James Rodriguez', type='Person', properties={'description': 'quiet but brilliant mind fascinated by artificial intelligence'}), Node(id='Sarah Chen', type='Person', properties={'description': 'aspiring game developers'}), Node(id='Michael Patel', type='Person', properties={'description': 'aspiring game developers'}), Node(id='Aisha Khan', type='Person', properties={'description': 'social justice advocate'})]\n", + "[Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='James Rodriguez', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Sarah Chen', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Michael Patel', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='James Rodriguez', type='Person', properties={}), target=Node(id='Computer 
Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Sarah Chen', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Michael Patel', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='James Rodriguez', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Sarah Chen', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Michael Patel', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={})]\n", + "\n", + "visions to life. 
And then there was Aisha Khan, a social justice advocate who planned to\n", + "use her c\n", + "[Node(id='Veritas University', type='College', properties={}), Node(id='Computer Science', type='Deparatment', properties={}), Node(id='Aisha Khan', type='Person', properties={'description': 'social justice advocate'})]\n", + "[Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={})]\n", + "\n", + "The year was 2016, and a palpable buzz filled the air as the graduating class of Oakhaven\n", + "university\n", + "[Node(id='Oakhaven University', type='College', properties={'description': 'Oakhaven university'}), Node(id='Computer Science And Engineering', type='Deparatment', properties={'description': 'Computer science and Engineering'}), Node(id='2016', type='Year', properties={'description': '2016'}), Node(id='Alice Johnson', type='Person', properties={'description': 'a gifted programmer with a fascination for quantum computing, already collaborating with leading researchers in the field'}), Node(id='David Kim', type='Person', properties={'description': 'a brilliant theorist captivated by the intricacies of cryptography, eager to contribute to the development of secure communication systems'}), Node(id='Maria Rodriguez', type='Person', properties={'description': 'passionate about robotics and determined to push the boundaries of artificial intelligence'}), Node(id='Robert Lee', type='Person', properties={'description': 'passionate about robotics and determined to push the boundaries of artificial intelligence'}), Node(id='Chloe Brown', type='Person', properties={'description': 'a visionary with a deep interest in bioinformatics, driven to unlock the secrets of the human 
genome through computational analysis'}), Node(id='Beckman Auditorium', type='Deparatment', properties={'description': 'Beckman Auditorium'})]\n", + "[Relationship(source=Node(id='Alice Johnson', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='David Kim', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Maria Rodriguez', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Robert Lee', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Chloe Brown', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='PARTOF', properties={}), Relationship(source=Node(id='Alice Johnson', type='Person', properties={}), target=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='David Kim', type='Person', properties={}), target=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Maria Rodriguez', type='Person', properties={}), target=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Robert Lee', type='Person', properties={}), target=Node(id='Computer Science And Engineering', 
type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Chloe Brown', type='Person', properties={}), target=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Oakhaven University', type='College', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Alice Johnson', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='David Kim', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Maria Rodriguez', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Robert Lee', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Chloe Brown', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Oakhaven University', type='College', properties={}), target=Node(id='Beckman Auditorium', type='Deparatment', properties={}), type='PARTOF', properties={})]\n", + "\n", + "Chloe Brown, a visionary with a deep interest in bioinformatics, driven to unlock the\n", + "secrets of the\n", + "[Node(id='Chloe Brown', type='Person', properties={'description': 'a visionary with a deep interest in bioinformatics, driven to unlock the secrets of the human genome through computational analysis'})]\n", + "[]\n", + "\n", + "Emily Davis, a name synonymous with cybersecurity innovation, turned that passion into a\n", + "thriving bu\n", + "[Node(id='Emily Davis', type='Person', 
properties={'description': 'a name synonymous with cybersecurity innovation'}), Node(id='Ironclad Security', type='Company', properties={'description': \"a company that's rapidly changing the landscape of cybersecurity solutions\"}), Node(id='2022', type='Year', properties={})]\n", + "[Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='Ironclad Security', type='Company', properties={}), type='FOUNDED', properties={}), Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='2022', type='Year', properties={}), type='FOUNDED', properties={})]\n", + "\n", + "Alice Johnson had a vision that extended far beyond the classroom. Driven by an insatiable\n", + "curiosity\n", + "[Node(id='Alice Johnson', type='Person', properties={'description': 'Driven by an insatiable curiosity about the potential of quantum mechanics, she founded Entangled Solutions, a company poised to revolutionize industries through the power of quantum technology.'}), Node(id='Entangled Solutions', type='Company', properties={'description': 'Entangled Solutions distinguishes itself by focusing on practical applications of quantum computing.'})]\n", + "[Relationship(source=Node(id='Alice Johnson', type='Person', properties={}), target=Node(id='Entangled Solutions', type='Company', properties={}), type='FOUNDED', properties={})]\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Load the graph to Spanner Graph database\n", + "# Uncomment the line below, if you want to cleanup from\n", + "# previous iterations.\n", + "# BeWARE - THIS COULD REMOVE DATA FROM YOUR DATABASE !!!\n", + "# graph_store.cleanup()\n", + "\n", + "\n", + "for graph_document in graph_documents:\n", + " graph_store.add_graph_documents([graph_document])" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lMXvOpRbZdau", + "outputId": "26647456-2316-46e3-de43-cfc9845a1050" + }, + "execution_count": 18, 
+ "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Waiting for DDL operations to complete...\n", + "Insert nodes of type `College`...\n", + "Insert nodes of type `Deparatment`...\n", + "Insert nodes of type `Year`...\n", + "Insert nodes of type `Person`...\n", + "Insert edges of type `Person_ALUMNIOF_College`...\n", + "Insert edges of type `Person_STUDIEDINDEPARTMENT_Deparatment`...\n", + "Insert edges of type `Person_GRADUATEDINYEAR_Year`...\n", + "No schema change required...\n", + "Insert nodes of type `College`...\n", + "Insert nodes of type `Deparatment`...\n", + "Insert nodes of type `Person`...\n", + "Insert edges of type `Person_ALUMNIOF_College`...\n", + "Insert edges of type `Person_STUDIEDINDEPARTMENT_Deparatment`...\n", + "Waiting for DDL operations to complete...\n", + "Insert nodes of type `College`...\n", + "Insert nodes of type `Deparatment`...\n", + "Insert nodes of type `Year`...\n", + "Insert nodes of type `Person`...\n", + "Insert edges of type `Person_ALUMNIOF_College`...\n", + "Insert edges of type `Deparatment_PARTOF_College`...\n", + "Insert edges of type `Person_STUDIEDINDEPARTMENT_Deparatment`...\n", + "Insert edges of type `College_GRADUATEDINYEAR_Year`...\n", + "Insert edges of type `Person_GRADUATEDINYEAR_Year`...\n", + "Insert edges of type `College_PARTOF_Deparatment`...\n", + "No schema change required...\n", + "Insert nodes of type `Person`...\n", + "Waiting for DDL operations to complete...\n", + "Insert nodes of type `Person`...\n", + "Insert nodes of type `Company`...\n", + "Insert nodes of type `Year`...\n", + "Insert edges of type `Person_FOUNDED_Company`...\n", + "Insert edges of type `Person_FOUNDED_Year`...\n", + "No schema change required...\n", + "Insert nodes of type `Person`...\n", + "Insert nodes of type `Company`...\n", + "Insert edges of type `Person_FOUNDED_Company`...\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Initialize the Spanner Graph QA Chain\n", + "The 
Spanner Graph QA Chain takes two parameters, a SpannerGraphStore object and a language model." + ], + "metadata": { + "id": "qlKwtdGN7kaT" + } + }, + { + "cell_type": "code", + "source": [ + "from google.cloud import spanner\n", + "from langchain_google_vertexai import ChatVertexAI\n", + "from IPython.core.display import HTML\n", + "\n", + "# Initialize llm object\n", + "llm = ChatVertexAI(model=\"gemini-1.5-flash-002\", temperature=0)\n", + "\n", + "# Initialize GraphQAChain\n", + "chain = SpannerGraphQAChain.from_llm(\n", + " llm,\n", + " graph=graph_store,\n", + " allow_dangerous_requests=True,\n", + " verbose=True,\n", + " return_intermediate_steps=True\n", + ")" + ], + "metadata": { + "id": "7yKDAD9s7t7O" + }, + "execution_count": 30, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title Run Spanner Graph QA Chain 1\n", + "question = \"Who are the alumni of the college id Veritas University ?\" # @param {type:\"string\"}\n", + "response = chain.invoke(\"query=\" + question)\n", + "response[\"result\"]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 264 + }, + "id": "ukKi9wtH_bF1", + "outputId": "61b66dcb-54cf-4620-a097-b4f0d732d1e3" + }, + "execution_count": 33, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new SpannerGraphQAChain chain...\u001b[0m\n", + "Executing gql:\n", + "\u001b[32;1m\u001b[1;3mGRAPH graph_demo_2\n", + "MATCH (p:Person)-[:ALUMNIOF]->(c:College {id: \"Veritas University\"})\n", + "RETURN p.id AS person_id, c.id AS college_id\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[{'person_id': 'Aisha Khan', 'college_id': 'Veritas University'}, {'person_id': 'Emily Davis', 'college_id': 'Veritas University'}, {'person_id': 'James Rodriguez', 'college_id': 'Veritas University'}, {'person_id': 'Michael Patel', 'college_id': 'Veritas University'}, {'person_id': 'Sarah Chen', 'college_id': 'Veritas 
University'}]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "'Aisha Khan, Emily Davis, James Rodriguez, Michael Patel, and Sarah Chen are alumni of Veritas University.\\n'" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } + }, + "metadata": {}, + "execution_count": 33 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Run Spanner Graph QA Chain 2\n", + "question = \"List the companies, their founders and the college they attended.\" # @param {type:\"string\"}\n", + "response = chain.invoke(\"query=\" + question)\n", + "response[\"result\"]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 + }, + "outputId": "e47d9f63-6769-49bc-b3a3-412c10de5c8a", + "id": "lcBc4tG__7Rm" + }, + "execution_count": 34, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new SpannerGraphQAChain chain...\u001b[0m\n", + "Executing gql:\n", + "\u001b[32;1m\u001b[1;3mGRAPH graph_demo_2\n", + "MATCH (p:Person)-[:FOUNDED]->(c:Company), (p)-[:ALUMNIOF]->(cl:College)\n", + "RETURN c.id AS company_id, c.description AS company_description, p.id AS founder_id, p.description AS founder_description, cl.id AS college_id, cl.description AS college_description\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[{'company_id': 'Entangled Solutions', 'company_description': 'Entangled Solutions distinguishes itself by focusing on practical applications of quantum computing.', 'founder_id': 'Alice Johnson', 'founder_description': 'Driven by an insatiable curiosity about the potential of quantum mechanics, she founded Entangled Solutions, a company poised to revolutionize industries through the power of quantum technology.', 'college_id': 'Oakhaven University', 'college_description': 'Oakhaven university'}, {'company_id': 'Ironclad Security', 
'company_description': \"a company that's rapidly changing the landscape of cybersecurity solutions\", 'founder_id': 'Emily Davis', 'founder_description': 'a name synonymous with cybersecurity innovation', 'college_id': 'Veritas University', 'college_description': 'grand halls'}]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "'Entangled Solutions, founded by Alice Johnson who attended Oakhaven University, focuses on practical applications of quantum computing. Ironclad Security, founded by Emily Davis who attended Veritas University, is rapidly changing the landscape of cybersecurity solutions.\\n'" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } + }, + "metadata": {}, + "execution_count": 34 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Run Spanner Graph QA Chain 3\n", + "question = \"Which companies were founded by alumni of college id Veritas University ? 
Who were the founders ?\" # @param {type:\"string\"}\n", + "response = chain.invoke(\"query=\" + question)\n", + "response[\"result\"]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 264 + }, + "outputId": "cb40179e-bcec-4399-df9d-a114e02b33f9", + "id": "e6djmq1NAGOM" + }, + "execution_count": 35, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new SpannerGraphQAChain chain...\u001b[0m\n", + "Executing gql:\n", + "\u001b[32;1m\u001b[1;3mGRAPH graph_demo_2\n", + "MATCH (c:College {id: \"Veritas University\"})<-[:ALUMNIOF]-(p:Person)-[:FOUNDED]->(co:Company)\n", + "RETURN co.id AS company_id, co.description AS company_description, p.id AS founder_id, p.description AS founder_description\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[{'company_id': 'Ironclad Security', 'company_description': \"a company that's rapidly changing the landscape of cybersecurity solutions\", 'founder_id': 'Emily Davis', 'founder_description': 'a name synonymous with cybersecurity innovation'}]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "\"Ironclad Security, a company that's rapidly changing the landscape of cybersecurity solutions, was founded by Emily Davis, a name synonymous with cybersecurity innovation.\\n\"" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } + }, + "metadata": {}, + "execution_count": 35 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "#### Clean up the graph\n", + "\n", + "> USE IT WITH CAUTION!\n", + "\n", + "Clean up all the nodes/edges in your graph and remove your graph definition." 
+ ], + "metadata": { + "id": "pM7TmfI0TEFy" + } + }, + { + "cell_type": "code", + "source": [ + "graph_store.cleanup()" + ], + "metadata": { + "id": "UQWq4-sITOgl" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ec56696..3af70b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,6 @@ +google-cloud-spanner==3.49.1 langchain-core==0.3.9 langchain-community==0.3.1 -google-cloud-spanner==3.49.1 +langchain-experimental==0.3.2 +langchain_google_vertexai +pydantic==2.9.1 diff --git a/src/langchain_google_spanner/__init__.py b/src/langchain_google_spanner/__init__.py index fb19446..5f5b5ae 100644 --- a/src/langchain_google_spanner/__init__.py +++ b/src/langchain_google_spanner/__init__.py @@ -14,6 +14,7 @@ from langchain_google_spanner.chat_message_history import SpannerChatMessageHistory from langchain_google_spanner.graph_store import SpannerGraphStore +from langchain_google_spanner.graph_qa import SpannerGraphQAChain from langchain_google_spanner.vector_store import ( DistanceStrategy, QueryParameters, @@ -32,6 +33,7 @@ "SpannerDocumentSaver", "SpannerLoader", "SpannerGraphStore", + "SpannerGraphQAChain", "TableColumn", "SecondaryIndex", "QueryParameters", diff --git a/src/langchain_google_spanner/graph_qa.py b/src/langchain_google_spanner/graph_qa.py new file mode 100644 index 0000000..14d546f --- /dev/null +++ b/src/langchain_google_spanner/graph_qa.py @@ -0,0 +1,396 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import re +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.output_parsers import JsonOutputParser, StrOutputParser +from langchain_core.prompts import BasePromptTemplate +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.runnables import RunnableSequence +from pydantic.v1 import BaseModel, Field + +from .graph_store import SpannerGraphStore +from .prompts import ( + DEFAULT_GQL_FIX_TEMPLATE, + DEFAULT_GQL_TEMPLATE, + DEFAULT_GQL_VERIFY_TEMPLATE, + SPANNERGRAPH_QA_TEMPLATE, +) + +GQL_GENERATION_PROMPT = PromptTemplate( + template=DEFAULT_GQL_TEMPLATE, + input_variables=["question", "schema"], +) + + +class VerifyGqlOutput(BaseModel): + input_gql: str + made_change: bool + explanation: str + verified_gql: str + + +verify_gql_output_parser = JsonOutputParser(pydantic_object=VerifyGqlOutput) + +GQL_VERIFY_PROMPT = PromptTemplate( + template=DEFAULT_GQL_VERIFY_TEMPLATE, + input_variables=["question", "generated_gql", "graph_schema"], + partial_variables={ + "format_instructions": verify_gql_output_parser.get_format_instructions() + }, +) + +GQL_FIX_PROMPT = PromptTemplate( + template=DEFAULT_GQL_FIX_TEMPLATE, + input_variables=["question", "generated_gql", "err_msg", "schema"], +) + +SPANNERGRAPH_QA_PROMPT = PromptTemplate( + template=SPANNERGRAPH_QA_TEMPLATE, + input_variables=["question", "graph_schema", "graph_query", "context"], +) + +INTERMEDIATE_STEPS_KEY = "intermediate_steps" + + +def fix_gql_syntax(query: str) -> str: + """Fixes the syntax of a GQL query. + + Args: + query: The input GQL query. + + Returns: + Possibly modified GQL query. 
+ """ + + query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]->", r"-[\1:\2]->{\3,\4}", query) + query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]->", r"-[\1:\2]->{\3}", query) + query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"<-[\1:\2]-{\3,\4}", query) + query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\]-", r"<-[\1:\2]-{\3}", query) + return query + + +def extract_gql(text: str) -> str: + """Extract GQL query from a text. + + Args: + text: Text to extract GQL query from. + + Returns: + GQL query extracted from the text. + """ + pattern = r"```(.*?)```" + matches = re.findall(pattern, text, re.DOTALL) + query = matches[0] if matches else text + return fix_gql_syntax(query) + + +def extract_verified_gql(json_response: str) -> str: + """Extract GQL query from a LLM response. + + Args: + response: Response to extract GQL query from. + + Returns: + GQL query extracted from the text. + """ + + json_response["verified_gql"] = fix_gql_syntax(str(json_response["verified_gql"])) + return json_response["verified_gql"] + + +class SpannerGraphQAChain(Chain): + """Chain for question-answering against a Spanner Graph database by + generating GQL statements from natural language questions. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as + appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. 
+ """ + + graph: SpannerGraphStore = Field(exclude=True) + gql_generation_chain: RunnableSequence + gql_fix_chain: RunnableSequence + gql_verify_chain: RunnableSequence + qa_chain: RunnableSequence + max_gql_fix_retries: int = 1 + """ Number of retries to fix an errornous generated graph query.""" + top_k: int = 10 + """Restricts the number of results returned in the graph query.""" + return_intermediate_steps: bool = False + """Whether to return the intermediate steps along with the final answer.""" + verify_gql: bool = True + """Whether to have a stage in the chain to verify and fix the generated GQL.""" + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + allow_dangerous_requests: bool = False + """Forced user opt-in to acknowledge that the chain can make dangerous requests. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + def __init__(self, **kwargs: Any) -> None: + """Initialize the chain.""" + super().__init__(**kwargs) + if not self.allow_dangerous_requests: + raise ValueError( + "In order to use this chain, you must acknowledge that it can make " + "dangerous requests by setting `allow_dangerous_requests` to `True`." + "You must narrowly scope the permissions of the database connection " + "to only include necessary permissions. 
Failure to do so may result " + "in data corruption or loss or reading sensitive data if such data is " + "present in the database. " + "Only use this chain if you understand the risks and have taken the " + "necessary precautions. " + "See https://python.langchain.com/docs/security for more information." + ) + + @property + def input_keys(self) -> List[str]: + """Input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Output keys. + + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel = None, + *, + qa_prompt: BasePromptTemplate = None, + gql_prompt: BasePromptTemplate = None, + gql_verify_prompt: BasePromptTemplate = None, + gql_fix_prompt: BasePromptTemplate = None, + qa_llm_kwargs: Optional[Dict[str, Any]] = None, + gql_llm_kwargs: Optional[Dict[str, Any]] = None, + gql_verify_llm_kwargs: Optional[Dict[str, Any]] = None, + gql_fix_llm_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> SpannerGraphQAChain: + """Initialize from LLM.""" + if not llm: + raise ValueError("`llm` parameter must be provided") + if gql_prompt and gql_llm_kwargs: + raise ValueError( + "Specifying gql_prompt and gql_llm_kwargs together is" + " not allowed. Please pass prompt via gql_llm_kwargs." + ) + if gql_fix_prompt and gql_fix_llm_kwargs: + raise ValueError( + "Specifying gql_fix_prompt and gql_fix_llm_kwargs together is" + " not allowed. Please pass prompt via gql_fix_llm_kwargs." + ) + if qa_prompt and qa_llm_kwargs: + raise ValueError( + "Specifying qa_prompt and qa_llm_kwargs together is" + " not allowed. Please pass prompt via qa_llm_kwargs." 
+ ) + + use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {} + use_gql_llm_kwargs = gql_llm_kwargs if gql_llm_kwargs is not None else {} + use_gql_verify_llm_kwargs = ( + gql_verify_llm_kwargs if gql_verify_llm_kwargs is not None else {} + ) + use_gql_fix_llm_kwargs = ( + gql_fix_llm_kwargs if gql_fix_llm_kwargs is not None else {} + ) + + if "prompt" not in use_qa_llm_kwargs: + use_qa_llm_kwargs["prompt"] = ( + qa_prompt if qa_prompt is not None else SPANNERGRAPH_QA_PROMPT + ) + if "prompt" not in use_gql_llm_kwargs: + use_gql_llm_kwargs["prompt"] = ( + gql_prompt if gql_prompt is not None else GQL_GENERATION_PROMPT + ) + if "prompt" not in use_gql_verify_llm_kwargs: + use_gql_verify_llm_kwargs["prompt"] = ( + gql_verify_prompt + if gql_verify_prompt is not None + else GQL_VERIFY_PROMPT + ) + if "prompt" not in use_gql_fix_llm_kwargs: + use_gql_fix_llm_kwargs["prompt"] = ( + gql_fix_prompt if gql_fix_prompt is not None else GQL_FIX_PROMPT + ) + + gql_generation_chain = use_gql_llm_kwargs["prompt"] | llm | StrOutputParser() + gql_fix_chain = use_gql_fix_llm_kwargs["prompt"] | llm | StrOutputParser() + gql_verify_chain = ( + use_gql_verify_llm_kwargs["prompt"] | llm | verify_gql_output_parser + ) + qa_chain = use_qa_llm_kwargs["prompt"] | llm | StrOutputParser() + + return cls( + gql_generation_chain=gql_generation_chain, + gql_fix_chain=gql_fix_chain, + gql_verify_chain=gql_verify_chain, + qa_chain=qa_chain, + **kwargs, + ) + + def execute_query( + self, _run_manager: CallbackManagerForChainRun, gql_query: str + ) -> List[Any]: + try: + _run_manager.on_text("Executing gql:", end="\n", verbose=self.verbose) + _run_manager.on_text( + gql_query, color="green", end="\n", verbose=self.verbose + ) + return self.graph.query(gql_query)[: self.top_k] + except Exception as e: + raise ValueError(str(e)) + + def execute_with_retry( + self, + _run_manager: CallbackManagerForChainRun, + intermediate_steps: List, + question: str, + gql_query: str, + ) -> 
tuple[str, List[Any]]: + try: + intermediate_steps.append({"generated_query": gql_query}) + return gql_query, self.execute_query(_run_manager, gql_query) + except Exception as e: + retries = 0 + err_msg = str(e) + self.log_invalid_query(_run_manager, gql_query, err_msg) + intermediate_steps.pop() + intermediate_steps.append({"query_failed_" + str(retries + 1): gql_query}) + + new_gql_query = "" + while retries < self.max_gql_fix_retries: + try: + fix_chain_result = self.gql_fix_chain.invoke( + { + "question": question, + "err_msg": err_msg, + "generated_gql": gql_query, + "schema": self.graph.get_schema, + } + ) + new_gql_query = extract_gql(fix_chain_result) + intermediate_steps.append({"generated_query": new_gql_query}) + return new_gql_query, self.execute_query( + _run_manager, new_gql_query + ) + except Exception as e: + retries += 1 + gql_query = new_gql_query + err_msg = str(e) + self.log_invalid_query(_run_manager, gql_query, err_msg) + raise ValueError("The generated gql query is invalid") + + def log_invalid_query( + self, + _run_manager: CallbackManagerForChainRun, + generated_gql: str, + err_msg: str, + ) -> None: + _run_manager.on_text("Invalid generated gql:", end="\n", verbose=self.verbose) + _run_manager.on_text(generated_gql, color="red", end="\n", verbose=self.verbose) + _run_manager.on_text( + "Query error: ", color="red", end="\n", verbose=self.verbose + ) + _run_manager.on_text(err_msg, color="red", end="\n", verbose=self.verbose) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + + intermediate_steps: List = [] + + """Generate gql statement, uses it to look up in db and answer question.""" + + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + question = inputs[self.input_key] + gen_response = self.gql_generation_chain.invoke( + {"question": question, "schema": self.graph.get_schema}, + ) + generated_gql = extract_gql(gen_response) + + 
if self.verify_gql: + verify_response = self.gql_verify_chain.invoke( + { + "question": question, + "generated_gql": generated_gql, + "graph_schema": self.graph.get_schema, + } + ) + verified_gql = extract_verified_gql(verify_response) + intermediate_steps.append({"verified_gql": verified_gql}) + else: + verified_gql = generated_gql + + final_gql = "" + if verified_gql: + (final_gql, context) = self.execute_with_retry( + _run_manager, intermediate_steps, question, verified_gql + ) + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + intermediate_steps.append({"context": context}) + else: + context = [] + + qa_result = self.qa_chain.invoke( + { + "question": question, + "graph_schema": self.graph.get_schema, + "graph_query": final_gql, + "context": str(context), + } + ) + chain_result: Dict[str, Any] = {self.output_key: qa_result} + if self.return_intermediate_steps: + chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps + + return chain_result diff --git a/src/langchain_google_spanner/prompts.py b/src/langchain_google_spanner/prompts.py new file mode 100644 index 0000000..86202b4 --- /dev/null +++ b/src/langchain_google_spanner/prompts.py @@ -0,0 +1,230 @@ +GQL_EXAMPLES = """ +The following query in backtick matches all persons in the graph FinGraph +whose birthday is before 1990-01-10 and +returns their name and birthday. +``` +GRAPH FinGraph +MATCH (p:Person WHERE p.birthday < '1990-01-10') +RETURN p.name as name, p.birthday as birthday; +``` + +The following query in backtick finds the owner of the account with the most +incoming transfers by chaining multiple graph linear statements together. 
+```
+GRAPH FinGraph
+MATCH (:Account)-[:Transfers]->(account:Account)
+RETURN account, COUNT(*) AS num_incoming_transfers
+GROUP BY account
+ORDER BY num_incoming_transfers DESC
+LIMIT 1
+
+NEXT
+
+MATCH (account:Account)<-[:Owns]-(owner:Person)
+RETURN account.id AS account_id, owner.name AS owner_name, num_incoming_transfers;
+```
+
+The following query finds all the destination accounts one to three transfers
+away from a source Account with id equal to 7.
+```
+GRAPH FinGraph
+MATCH (src:Account {{id: 7}})-[e:Transfers]->{{1, 3}}(dst:Account)
+RETURN src.id AS src_account_id, dst.id AS dst_account_id;
+```
+Carefully note the syntax in the example above for path quantification,
+that it is `[e:Transfers]->{{1, 3}}` and NOT `[e:Transfers*1..3]->`
+"""
+
+DEFAULT_GQL_TEMPLATE_PART0 = """
+Create an ISO GQL query for the question using the schema.
+{gql_examples}
+"""
+
+DEFAULT_GQL_TEMPLATE_PART1 = """
+Instructions:
+Mention the name of the graph at the beginning.
+Use only nodes and edge types, and properties included in the schema.
+Do not use any node and edge type, or properties not included in the schema.
+Always alias RETURN values.
+
+Question: {question}
+Schema: {schema}
+
+Note:
+Do not include any explanations or apologies.
+Do not prefix query with `gql`
+Do not include any backticks.
+Start with GRAPH
+Output only the query statement.
+Do not output any query that tries to modify or delete data.
+"""
+
+DEFAULT_GQL_TEMPLATE = (
+    DEFAULT_GQL_TEMPLATE_PART0.format(gql_examples=GQL_EXAMPLES)
+    + DEFAULT_GQL_TEMPLATE_PART1
+)
+
+VERIFY_EXAMPLES = """
+Examples:
+1.
+question: Which movie has won the Oscar award in 1996?
+generated_gql:
+    GRAPH moviedb
+    MATCH (m:movie)-[:own_award]->(a:award {{name:"Oscar", year:1996}})
+    RETURN m.name
+
+graph_schema:
+{{
+"Edges": {{
+    "produced_by": "From movie nodes to producer nodes",
+    "acts": "From actor nodes to movie nodes",
+    "has_coacted_with": "From actor nodes to actor nodes",
+    "own_award": "From actor nodes to award nodes"
+  }}
+}}
+
+The verified gql fixes the missing node 'actor'
+    MATCH (m:movie)<-[:acts]-(act:actor)-[:own_award]->(aw:award {{name:"Oscar", year:1996}})
+    RETURN m.name
+
+2.
+question: Which movies have been produced by production house ABC Movies?
+generated_gql:
+    GRAPH moviedb
+    MATCH (p:producer {{name:"ABC Movies"}})-[:produced_by]->(m:movie)
+    RETURN p.name
+
+graph_schema:
+{{
+"Edges": {{
+    "produced_by": "From movie nodes to producer nodes",
+    "acts": "From actor nodes to movie nodes",
+    "references": "From movie nodes to movie nodes",
+    "own_award": "From actor nodes to award nodes"
+  }}
+}}
+
+The verified gql fixes the edge direction:
+    GRAPH moviedb
+    MATCH (p:producer {{name:"ABC Movies"}})<-[:produced_by]-(m:movie)
+    RETURN m.name
+
+3.
+question: Which movie references the movie "XYZ" via at most 3 hops ?
+graph_schema:
+{{
+"Edges": {{
+    "produced_by": "From movie nodes to producer nodes",
+    "acts": "From actor nodes to movie nodes",
+    "references": "From movie nodes to movie nodes",
+    "own_award": "From actor nodes to award nodes"
+  }}
+}}
+
+generated_gql:
+    GRAPH moviedb
+    MATCH (m:movie)-[:references*1..3]->(:movie {{name: "XYZ"}})
+    RETURN m.name
+
+The path quantification syntax [:references*1..3] is wrong.
+The verified gql fixes the path quantification syntax:
+    GRAPH moviedb
+    MATCH (m:movie)-[:references]->{{1, 3}}(:movie {{name: "XYZ"}})
+    RETURN m.name
+"""
+
+DEFAULT_GQL_VERIFY_TEMPLATE_PART0 = """
+Given a natural language question, ISO GQL graph query and a graph schema,
+validate the query.
+
+{verify_examples}
+"""
+
+DEFAULT_GQL_VERIFY_TEMPLATE_PART1 = """
+Instructions:
+Add missing nodes and edges in the query if required.
+Fix the path quantification syntax if required.
+Carefully check the syntax.
+Fix the query if required. There could be more than one correction.
+Optimize if possible.
+Do not make changes if not required.
+Think in steps. Add the explanation in the output.
+
+Question: {question}
+Input gql: {generated_gql}
+Schema: {graph_schema}
+
+{format_instructions}
+"""
+
+DEFAULT_GQL_VERIFY_TEMPLATE = (
+    DEFAULT_GQL_VERIFY_TEMPLATE_PART0.format(verify_examples=VERIFY_EXAMPLES)
+    + DEFAULT_GQL_VERIFY_TEMPLATE_PART1
+)
+
+DEFAULT_GQL_FIX_TEMPLATE_PART0 = """
+We generated an ISO GQL query to answer a natural language question.
+Question: {question}
+However the generated ISO GQL query is not valid. ```
+Input gql: {generated_gql}
+```
+The error obtained when executing the query is
+```
+{err_msg}
+```
+Give me a correct version of the query.
+Do not generate the same query as the input gql.
+"""
+
+DEFAULT_GQL_FIX_TEMPLATE_PART1 = """
+Examples of correct query:
+{gql_examples}"""
+
+DEFAULT_GQL_FIX_TEMPLATE_PART2 = """
+Instructions:
+Mention the name of the graph at the beginning.
+Use only nodes and edge types, and properties included in the schema.
+Do not use any node and edge type, or properties not included in the schema.
+Do not generate the same query as the input gql.
+Schema: {schema}
+
+Note:
+Do not include any explanations or apologies.
+Do not prefix query with `gql`
+Do not include any backticks.
+Start with GRAPH
+Output only the query statement.
+Do not output any query that tries to modify or delete data.
+"""
+
+DEFAULT_GQL_FIX_TEMPLATE = (
+    DEFAULT_GQL_FIX_TEMPLATE_PART0
+    + DEFAULT_GQL_FIX_TEMPLATE_PART1.format(gql_examples=GQL_EXAMPLES)
+    + DEFAULT_GQL_FIX_TEMPLATE_PART2
+)
+
+SPANNERGRAPH_QA_TEMPLATE = """
+You are a helpful AI assistant.
+Create a human readable answer for the question.
+You should only use the information provided in the context and not use your internal knowledge. +Don't add any information. +Here is an example: + +Question: Which funds own assets over 10M? +Context:[name:ABC Fund, name:Star fund]" +Helpful Answer: ABC Fund and Star fund have assets over 10M. + +Follow this example when generating answers. +If the provided information is empty, say that you don't know the answer. +You are given the following information: +- `Question`: the natural language question from the user +- `Graph Schema`: contains the schema of the graph database +- `Graph Query`: A ISO GQL query equivalent of the question from the user used to extract context from the graph database +- `Context`: The response from the graph database as context +Information: +Question: {question} +Graph Schema: {graph_schema} +Graph Query: {graph_query} +Context: {context} + +Helpful Answer:""" diff --git a/tests/integration/test_spanner_graph_qa.py b/tests/integration/test_spanner_graph_qa.py new file mode 100644 index 0000000..bf6551f --- /dev/null +++ b/tests/integration/test_spanner_graph_qa.py @@ -0,0 +1,195 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import pytest +from google.cloud import spanner +from langchain.evaluation import load_evaluator +from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship +from langchain_core.documents import Document +from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings + +from langchain_google_spanner.graph_qa import SpannerGraphQAChain +from langchain_google_spanner.graph_store import SpannerGraphStore + + +project_id = os.environ["PROJECT_ID"] +instance_id = os.environ["INSTANCE_ID"] +database_id = os.environ["GOOGLE_DATABASE"] +graph_name = os.environ["GRAPH_NAME"] + + +def get_llm(): + llm = ChatVertexAI( + model="gemini-1.5-flash-002", + temperature=0, + ) + return llm + + +def get_evaluator(): + return load_evaluator( + "embedding_distance", + embeddings=VertexAIEmbeddings(model_name="text-embedding-004"), + ) + + +def get_spanner_graph(): + graph = SpannerGraphStore( + instance_id=instance_id, + database_id=database_id, + graph_name=graph_name, + client=spanner.Client(project=project_id), + ) + return graph + + +def load_data(graph: SpannerGraphStore): + graph_documents = [ + GraphDocument( + nodes=[ + Node( + id="Elias Thorne", + type="Person", + properties={ + "name": "Elias Thorne", + "description": "lived in the desert", + }, + ), + Node( + id="Zephyr", + type="Animal", + properties={"name": "Zephyr", "description": "pet falcon"}, + ), + Node( + id="Elara", + type="Person", + properties={ + "name": "Elara", + "description": "resided in the capital city", + }, + ), + Node(id="Desert", type="Location", properties={}), + Node(id="Capital City", type="Location", properties={}), + ], + relationships=[ + Relationship( + source=Node(id="Elias Thorne", type="Person", properties={}), + target=Node(id="Desert", type="Location", properties={}), + type="LivesIn", + properties={}, + ), + Relationship( + source=Node(id="Elias Thorne", type="Person", properties={}), + target=Node(id="Zephyr", type="Animal", 
properties={}), + type="Owns", + properties={}, + ), + Relationship( + source=Node(id="Elara", type="Person", properties={}), + target=Node(id="Capital City", type="Location", properties={}), + type="LivesIn", + properties={}, + ), + Relationship( + source=Node(id="Elias Thorne", type="Person", properties={}), + target=Node(id="Elara", type="Person", properties={}), + type="Sibling", + properties={}, + ), + ], + source=Document( + metadata={}, + page_content=( + "Elias Thorne lived in the desert. He was a skilled craftsman" + " who worked with sandstone. Elias had a pet falcon named" + " Zephyr. His sister, Elara, resided in the capital city and" + " ran a spice shop. They rarely met due to the distance." + ), + ), + ) + ] + graph.add_graph_documents(graph_documents) + graph.refresh_schema() + + +class TestSpannerGraphQAChain: + + @pytest.fixture + def setup_db_load_data(self): + graph = get_spanner_graph() + load_data(graph) + yield graph + # teardown + graph.cleanup() + + @pytest.fixture + def chain(self, setup_db_load_data): + graph = setup_db_load_data + return SpannerGraphQAChain.from_llm( + get_llm(), + graph=graph, + verbose=True, + return_intermediate_steps=True, + allow_dangerous_requests=True, + ) + + @pytest.fixture + def chain_without_opt_in(self, setup_db_load_data): + graph = setup_db_load_data + return SpannerGraphQAChain.from_llm( + get_llm(), + graph=graph, + verbose=True, + return_intermediate_steps=True, + ) + + def test_spanner_graph_qa_chain_1(self, chain): + question = "Where does Elias Thorne's sibling live?" + response = chain.invoke("query=" + question) + print(response) + + answer = response["result"] + assert ( + get_evaluator().evaluate_strings( + prediction=answer, + reference="Elias Thorne's sibling lives in Capital City.\n", + )["score"] + < 0.1 + ) + + def test_spanner_graph_qa_chain_no_answer(self, chain): + question = "Where does Sarah's sibling live?" 
+ response = chain.invoke("query=" + question) + print(response) + + answer = response["result"] + assert ( + get_evaluator().evaluate_strings( + prediction=answer, + reference="I don't know the answer.\n", + )["score"] + < 0.1 + ) + + def test_spanner_graph_qa_chain_without_opt_in(self, setup_db_load_data): + with pytest.raises(ValueError): + graph = setup_db_load_data + SpannerGraphQAChain.from_llm( + get_llm(), + graph=graph, + verbose=True, + return_intermediate_steps=True, + ) From d8e6640fb1a70bc1db4e68322170f3323aa4d1dd Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Wed, 27 Nov 2024 23:26:13 +0000 Subject: [PATCH 02/40] Formatted notebook. Added copyright message to prompts file. --- docs/graph_qa_chain.ipynb | 1330 ++++++++++---------- src/langchain_google_spanner/prompts.py | 14 + tests/integration/test_spanner_graph_qa.py | 4 +- 3 files changed, 684 insertions(+), 664 deletions(-) diff --git a/docs/graph_qa_chain.ipynb b/docs/graph_qa_chain.ipynb index e6a5716..4bc5ab3 100644 --- a/docs/graph_qa_chain.ipynb +++ b/docs/graph_qa_chain.ipynb @@ -1,689 +1,693 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - }, - "colab": { - "provenance": [], - "toc_visible": true - } + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# Google Spanner\n", - "\n", - "> [Spanner](https://cloud.google.com/spanner) is a highly scalable database that combines unlimited scalability with relational semantics, such as secondary indexes, 
strong consistency, schemas, and SQL providing 99.999% availability in one easy solution.\n", - "\n", - "This notebook goes over how to use `Spanner` for GraphRAG with `SpannerGraphStore` and `SpannerGraphQAChain` class.\n", - "\n", - "Learn more about the package on [GitHub](https://github.com/googleapis/langchain-google-spanner-python/).\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/langchain-google-spanner-python/blob/main/docs/graph_store.ipynb)" - ], - "metadata": { - "id": "7VBkjcqNNxEd" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Before You Begin\n", - "\n", - "To run this notebook, you will need to do the following:\n", - "\n", - " * [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n", - " * [Enable the Cloud Spanner API](https://console.cloud.google.com/flows/enableapi?apiid=spanner.googleapis.com)\n", - " * [Create a Spanner instance](https://cloud.google.com/spanner/docs/create-manage-instances)\n", - " * [Create a Spanner database](https://cloud.google.com/spanner/docs/create-manage-databases)" - ], - "metadata": { - "id": "HEAGYTPgNydh" - } - }, - { - "cell_type": "markdown", - "source": [ - "### 🦜🔗 Library Installation\n", - "The integration lives in its own `langchain-google-spanner` package, so we need to install it." - ], - "metadata": { - "id": "cboPIg-yOcxS" - } - }, - { - "cell_type": "code", - "source": [ - "%pip install --upgrade --quiet langchain-google-spanner langchain-google-vertexai langchain-experimental json-repair pyvis" - ], - "metadata": { - "id": "AOWh6QKYVdDp" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top." 
- ], - "metadata": { - "id": "M7MqpDhkOiP-" - } - }, - { - "cell_type": "code", - "source": [ - "# # Automatically restart kernel after installs so that your environment can access the new packages\n", - "import IPython\n", - "\n", - "app = IPython.Application.instance()\n", - "app.kernel.do_shutdown(True)" - ], - "metadata": { - "id": "xzgVZv0POj17" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "### 🔐 Authentication\n", - "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n", - "\n", - "* If you are using Colab to run this notebook, use the cell below and continue.\n", - "* If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)." - ], - "metadata": { - "id": "zfIhwIryOls1" - } - }, - { - "cell_type": "code", - "source": [ - "from google.colab import auth\n", - "\n", - "auth.authenticate_user()" - ], - "metadata": { - "id": "EWOkHI7XOna2" - }, - "execution_count": 1, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "### ☁ Set Your Google Cloud Project\n", - "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n", - "\n", - "If you don't know your project ID, try the following:\n", - "\n", - "* Run `gcloud config list`.\n", - "* Run `gcloud projects list`.\n", - "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)." 
- ], - "metadata": { - "id": "6xHXneICOpsB" - } - }, - { - "cell_type": "code", - "source": [ - "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n", - "\n", - "PROJECT_ID = \"google.com:cloud-spanner-demo\" # @param {type:\"string\"}\n", - "\n", - "# Set the project id\n", - "!gcloud config set project {PROJECT_ID}\n", - "%env GOOGLE_CLOUD_PROJECT={PROJECT_ID}" - ], - "metadata": { - "id": "hF0481BGOsS8" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "### 💡 API Enablement\n", - "The `langchain-google-spanner` package requires that you [enable the Spanner API](https://console.cloud.google.com/flows/enableapi?apiid=spanner.googleapis.com) in your Google Cloud Project." - ], - "metadata": { - "id": "4TiC0RbhOwUu" - } - }, - { - "cell_type": "code", - "source": [ - "# enable Spanner API\n", - "!gcloud services enable spanner.googleapis.com" - ], - "metadata": { - "id": "9f3fJd5eOyRr" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "You must also and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com)." - ], - "metadata": { - "id": "bT_S-jaEOW4P" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Basic Usage" - ], - "metadata": { - "id": "k5pxMMiMOzt7" - } - }, - { - "cell_type": "markdown", - "source": [ - "### Set Spanner database values\n", - "Find your database values, in the [Spanner Instances page](https://console.cloud.google.com/spanner?_ga=2.223735448.2062268965.1707700487-2088871159.1707257687)." 
- ], - "metadata": { - "id": "mtDbLU5sO2iA" - } + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + }, + "colab": { + "provenance": [], + "toc_visible": true + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Google Spanner\n", + "\n", + "> [Spanner](https://cloud.google.com/spanner) is a highly scalable database that combines unlimited scalability with relational semantics, such as secondary indexes, strong consistency, schemas, and SQL providing 99.999% availability in one easy solution.\n", + "\n", + "This notebook goes over how to use `Spanner` for GraphRAG with `SpannerGraphStore` and `SpannerGraphQAChain` class.\n", + "\n", + "Learn more about the package on [GitHub](https://github.com/googleapis/langchain-google-spanner-python/).\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/langchain-google-spanner-python/blob/main/docs/graph_store.ipynb)" + ], + "metadata": { + "id": "7VBkjcqNNxEd" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Before You Begin\n", + "\n", + "To run this notebook, you will need to do the following:\n", + "\n", + " * [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n", + " * [Enable the Cloud Spanner API](https://console.cloud.google.com/flows/enableapi?apiid=spanner.googleapis.com)\n", + " * [Create a Spanner instance](https://cloud.google.com/spanner/docs/create-manage-instances)\n", + " * [Create a Spanner database](https://cloud.google.com/spanner/docs/create-manage-databases)" + ], + "metadata": { + "id": "HEAGYTPgNydh" + } + }, + { + "cell_type": "markdown", + "source": [ + "### 🦜🔗 Library Installation\n", + "The integration lives in its own 
`langchain-google-spanner` package, so we need to install it." + ], + "metadata": { + "id": "cboPIg-yOcxS" + } + }, + { + "cell_type": "code", + "source": [ + "%pip install --upgrade --quiet langchain-google-spanner langchain-google-vertexai langchain-experimental json-repair pyvis" + ], + "metadata": { + "id": "AOWh6QKYVdDp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top." + ], + "metadata": { + "id": "M7MqpDhkOiP-" + } + }, + { + "cell_type": "code", + "source": [ + "# # Automatically restart kernel after installs so that your environment can access the new packages\n", + "import IPython\n", + "\n", + "app = IPython.Application.instance()\n", + "app.kernel.do_shutdown(True)" + ], + "metadata": { + "id": "xzgVZv0POj17" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### 🔐 Authentication\n", + "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n", + "\n", + "* If you are using Colab to run this notebook, use the cell below and continue.\n", + "* If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)." 
+ ], + "metadata": { + "id": "zfIhwIryOls1" + } + }, + { + "cell_type": "code", + "source": [ + "from google.colab import auth\n", + "\n", + "auth.authenticate_user()" + ], + "metadata": { + "id": "EWOkHI7XOna2" + }, + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### ☁ Set Your Google Cloud Project\n", + "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n", + "\n", + "If you don't know your project ID, try the following:\n", + "\n", + "* Run `gcloud config list`.\n", + "* Run `gcloud projects list`.\n", + "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)." + ], + "metadata": { + "id": "6xHXneICOpsB" + } + }, + { + "cell_type": "code", + "source": [ + "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n", + "\n", + "PROJECT_ID = \"google.com:cloud-spanner-demo\" # @param {type:\"string\"}\n", + "\n", + "# Set the project id\n", + "!gcloud config set project {PROJECT_ID}\n", + "%env GOOGLE_CLOUD_PROJECT={PROJECT_ID}" + ], + "metadata": { + "id": "hF0481BGOsS8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### 💡 API Enablement\n", + "The `langchain-google-spanner` package requires that you [enable the Spanner API](https://console.cloud.google.com/flows/enableapi?apiid=spanner.googleapis.com) in your Google Cloud Project." + ], + "metadata": { + "id": "4TiC0RbhOwUu" + } + }, + { + "cell_type": "code", + "source": [ + "# enable Spanner API\n", + "!gcloud services enable spanner.googleapis.com" + ], + "metadata": { + "id": "9f3fJd5eOyRr" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "You must also and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com)." 
+ ], + "metadata": { + "id": "bT_S-jaEOW4P" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Basic Usage" + ], + "metadata": { + "id": "k5pxMMiMOzt7" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Set Spanner database values\n", + "Find your database values, in the [Spanner Instances page](https://console.cloud.google.com/spanner?_ga=2.223735448.2062268965.1707700487-2088871159.1707257687)." + ], + "metadata": { + "id": "mtDbLU5sO2iA" + } + }, + { + "cell_type": "code", + "source": [ + "# @title Set Your Values Here { display-mode: \"form\" }\n", + "INSTANCE = \"\" # @param {type: \"string\"}\n", + "DATABASE = \"\" # @param {type: \"string\"}\n", + "GRAPH_NAME = \"\" # @param {type: \"string\"}" + ], + "metadata": { + "id": "C-I8VTIcO442" + }, + "execution_count": 15, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### SpannerGraphStore\n", + "\n", + "To initialize the `SpannerGraphStore` class you need to provide 3 required arguments and other arguments are optional and only need to pass if it's different from default ones\n", + "\n", + "1. a Spanner instance id;\n", + "2. a Spanner database id belongs to the above instance id;\n", + "3. a Spanner graph name used to create a graph in the above database." 
+   ],
+   "metadata": {
+    "id": "kpAv-tpcO_iL"
+   }
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "from langchain_google_spanner import SpannerGraphStore\n",
+    "\n",
+    "graph_store = SpannerGraphStore(\n",
+    "    instance_id=INSTANCE,\n",
+    "    database_id=DATABASE,\n",
+    "    graph_name=GRAPH_NAME,\n",
+    ")"
+   ],
+   "metadata": {
+    "id": "u589YapWQFb8"
+   },
+   "execution_count": 16,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "#### Add Graph Documents to Spanner Graph"
+   ],
+   "metadata": {
+    "id": "G7-Pe2ADQlNJ"
+   }
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "# @title Extract Nodes and Edges from text snippets\n",
+    "from langchain_core.documents import Document\n",
+    "from langchain_experimental.graph_transformers import LLMGraphTransformer\n",
+    "from langchain_google_vertexai import ChatVertexAI\n",
+    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
+    "\n",
+    "text_snippets = [\n",
+    "    # Text snippet for students graduating from Veritas University, Computer Science Dept 2017\n",
+    "    \"\"\"\n",
+    "This was the graduation ceremony of 2017. A wave of jubilant graduates poured out of the\n",
+    "grand halls of Veritas University, their laughter echoing across the quad. Among them were\n",
+    "a cohort of exceptional students from the Computer Science department, a group that had\n",
+    "become known for their collaborative spirit and innovative ideas.\n",
+    "Leading the pack was Emily Davis, a coding whiz with a passion for cybersecurity, already\n",
+    "fielding offers from top tech firms. Beside her walked James Rodriguez, a quiet but\n",
+    "brilliant mind fascinated by artificial intelligence, dreaming of building machines that\n",
+    "could understand human emotions. Trailing slightly behind, deep in conversation, were\n",
+    "Sarah Chen and Michael Patel, both aspiring game developers, eager to bring their creative\n",
+    "visions to life. 
And then there was Aisha Khan, a social justice advocate who planned to\n",
+    "use her coding skills to address inequality through technology.\n",
+    "As they celebrated their achievements, these Veritas University Computer Science graduates\n",
+    "were ready to embark on diverse paths, each carrying the potential to shape the future of\n",
+    "technology in their own unique way.\n",
+    "\"\"\",\n",
+    "    # Text snippet for students graduating from Oakhaven University, Computer Science Dept 2016\n",
+    "    \"\"\"\n",
+    "The year was 2016, and a palpable buzz filled the air as the graduating class of Oakhaven\n",
+    "university from Computer science and Engineering department emerged from the Beckman\n",
+    "Auditorium. Among them was a group of exceptional students, renowned for their\n",
+    "intellectual curiosity and groundbreaking research.\n",
+    "At the forefront was Alice Johnson, a gifted programmer with a fascination for quantum\n",
+    "computing, already collaborating with leading researchers in the field. Beside her\n",
+    "strode David Kim, a brilliant theorist captivated by the intricacies of cryptography,\n",
+    "eager to contribute to the development of secure communication systems. Engaged in an\n",
+    "animated discussion were Maria Rodriguez and Robert Lee, both passionate about robotics\n",
+    "and determined to push the boundaries of artificial intelligence. 
And then there was\n", + "Chloe Brown, a visionary with a deep interest in bioinformatics, driven to unlock the\n", + "secrets of the human genome through computational analysis.\n", + "As they celebrated their accomplishments, these graduates, armed with their exceptional\n", + "skills and unwavering determination, were poised to make significant contributions to the world of computing and beyond.\n", + "\"\"\",\n", + " # Text snippet mentions the company Emily Davis founded.\n", + " # The snippet doesn't mention that she is an alumni of Veritas University\n", + " \"\"\"\n", + "Emily Davis, a name synonymous with cybersecurity innovation, turned that passion into a\n", + "thriving business. In the year 2022, Davis founded Ironclad Security, a company that's\n", + "rapidly changing the landscape of cybersecurity solutions.\n", + "\"\"\",\n", + " # Text snippet mentions the company Alice Johnson founded.\n", + " # The snippet doesn't mention that she is an alumni of Oakhaven University.\n", + " \"\"\"\n", + "Alice Johnson had a vision that extended far beyond the classroom. 
Driven by an insatiable\n", + "curiosity about the potential of quantum mechanics, she founded Entangled Solutions, a\n", + "company poised to revolutionize industries through the power of quantum technology.\n", + "Entangled Solutions distinguishes itself by focusing on practical applications of quantum\n", + "computing.\n", + "\"\"\",\n", + "]\n", + "\n", + "# Create splits for documents\n", + "documents = [Document(page_content=t) for t in text_snippets]\n", + "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n", + "splits = text_splitter.split_documents(documents)\n", + "llm = ChatVertexAI(model=\"gemini-1.5-flash\", temperature=0)\n", + "llm_transformer = LLMGraphTransformer(\n", + " llm=llm,\n", + " allowed_nodes=[\"College\", \"Deparatment\", \"Person\", \"Year\", \"Company\"],\n", + " allowed_relationships=[\n", + " \"AlumniOf\",\n", + " \"StudiedInDepartment\",\n", + " \"PartOf\",\n", + " \"GraduatedInYear\",\n", + " \"Founded\",\n", + " ],\n", + " node_properties=[\n", + " \"description\",\n", + " ],\n", + ")\n", + "graph_documents = llm_transformer.convert_to_graph_documents(splits)" + ], + "metadata": { + "id": "fP7XNu3aPl5c" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title Print extracted nodes and edges\n", + "for doc in graph_documents:\n", + " print(doc.source.page_content[:100])\n", + " print(doc.nodes)\n", + " print(doc.relationships)\n", + " print()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "OylyNyv-ZsT2", + "outputId": "e4253d98-ad63-4ea8-a5f1-0e3dac8f6632" + }, + "execution_count": 13, + "outputs": [ { - "cell_type": "code", - "source": [ - "# @title Set Your Values Here { display-mode: \"form\" }\n", - "INSTANCE = \"\" # @param {type: \"string\"}\n", - "DATABASE = \"\" # @param {type: \"string\"}\n", - "GRAPH_NAME = \"\" # @param {type: \"string\"}" - ], - "metadata": { - "id": "C-I8VTIcO442" - }, - 
"execution_count": 15, - "outputs": [] + "output_type": "stream", + "name": "stdout", + "text": [ + "This was the graduation ceremony of 2017. A wave of jubilant graduates poured out of the\n", + "grand halls\n", + "[Node(id='Veritas University', type='College', properties={'description': 'grand halls'}), Node(id='Computer Science', type='Deparatment', properties={}), Node(id='2017', type='Year', properties={}), Node(id='Emily Davis', type='Person', properties={'description': 'coding whiz with a passion for cybersecurity'}), Node(id='James Rodriguez', type='Person', properties={'description': 'quiet but brilliant mind fascinated by artificial intelligence'}), Node(id='Sarah Chen', type='Person', properties={'description': 'aspiring game developers'}), Node(id='Michael Patel', type='Person', properties={'description': 'aspiring game developers'}), Node(id='Aisha Khan', type='Person', properties={'description': 'social justice advocate'})]\n", + "[Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='James Rodriguez', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Sarah Chen', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Michael Patel', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), 
type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='James Rodriguez', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Sarah Chen', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Michael Patel', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='James Rodriguez', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Sarah Chen', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Michael Patel', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={})]\n", + "\n", + "visions to life. 
And then there was Aisha Khan, a social justice advocate who planned to\n", + "use her c\n", + "[Node(id='Veritas University', type='College', properties={}), Node(id='Computer Science', type='Deparatment', properties={}), Node(id='Aisha Khan', type='Person', properties={'description': 'social justice advocate'})]\n", + "[Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={})]\n", + "\n", + "The year was 2016, and a palpable buzz filled the air as the graduating class of Oakhaven\n", + "university\n", + "[Node(id='Oakhaven University', type='College', properties={'description': 'Oakhaven university'}), Node(id='Computer Science And Engineering', type='Deparatment', properties={'description': 'Computer science and Engineering'}), Node(id='2016', type='Year', properties={'description': '2016'}), Node(id='Alice Johnson', type='Person', properties={'description': 'a gifted programmer with a fascination for quantum computing, already collaborating with leading researchers in the field'}), Node(id='David Kim', type='Person', properties={'description': 'a brilliant theorist captivated by the intricacies of cryptography, eager to contribute to the development of secure communication systems'}), Node(id='Maria Rodriguez', type='Person', properties={'description': 'passionate about robotics and determined to push the boundaries of artificial intelligence'}), Node(id='Robert Lee', type='Person', properties={'description': 'passionate about robotics and determined to push the boundaries of artificial intelligence'}), Node(id='Chloe Brown', type='Person', properties={'description': 'a visionary with a deep interest in bioinformatics, driven to unlock the secrets of the human 
genome through computational analysis'}), Node(id='Beckman Auditorium', type='Deparatment', properties={'description': 'Beckman Auditorium'})]\n", + "[Relationship(source=Node(id='Alice Johnson', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='David Kim', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Maria Rodriguez', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Robert Lee', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Chloe Brown', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='PARTOF', properties={}), Relationship(source=Node(id='Alice Johnson', type='Person', properties={}), target=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='David Kim', type='Person', properties={}), target=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Maria Rodriguez', type='Person', properties={}), target=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Robert Lee', type='Person', properties={}), target=Node(id='Computer Science And Engineering', 
type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Chloe Brown', type='Person', properties={}), target=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Oakhaven University', type='College', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Alice Johnson', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='David Kim', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Maria Rodriguez', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Robert Lee', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Chloe Brown', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Oakhaven University', type='College', properties={}), target=Node(id='Beckman Auditorium', type='Deparatment', properties={}), type='PARTOF', properties={})]\n", + "\n", + "Chloe Brown, a visionary with a deep interest in bioinformatics, driven to unlock the\n", + "secrets of the\n", + "[Node(id='Chloe Brown', type='Person', properties={'description': 'a visionary with a deep interest in bioinformatics, driven to unlock the secrets of the human genome through computational analysis'})]\n", + "[]\n", + "\n", + "Emily Davis, a name synonymous with cybersecurity innovation, turned that passion into a\n", + "thriving bu\n", + "[Node(id='Emily Davis', type='Person', 
properties={'description': 'a name synonymous with cybersecurity innovation'}), Node(id='Ironclad Security', type='Company', properties={'description': \"a company that's rapidly changing the landscape of cybersecurity solutions\"}), Node(id='2022', type='Year', properties={})]\n", + "[Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='Ironclad Security', type='Company', properties={}), type='FOUNDED', properties={}), Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='2022', type='Year', properties={}), type='FOUNDED', properties={})]\n", + "\n", + "Alice Johnson had a vision that extended far beyond the classroom. Driven by an insatiable\n", + "curiosity\n", + "[Node(id='Alice Johnson', type='Person', properties={'description': 'Driven by an insatiable curiosity about the potential of quantum mechanics, she founded Entangled Solutions, a company poised to revolutionize industries through the power of quantum technology.'}), Node(id='Entangled Solutions', type='Company', properties={'description': 'Entangled Solutions distinguishes itself by focusing on practical applications of quantum computing.'})]\n", + "[Relationship(source=Node(id='Alice Johnson', type='Person', properties={}), target=Node(id='Entangled Solutions', type='Company', properties={}), type='FOUNDED', properties={})]\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Load the graph to Spanner Graph database\n", + "# Uncomment the line below, if you want to cleanup from\n", + "# previous iterations.\n", + "# BeWARE - THIS COULD REMOVE DATA FROM YOUR DATABASE !!!\n", + "# graph_store.cleanup()\n", + "\n", + "\n", + "for graph_document in graph_documents:\n", + " graph_store.add_graph_documents([graph_document])" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "lMXvOpRbZdau", + "outputId": "26647456-2316-46e3-de43-cfc9845a1050" + }, + "execution_count": 18, + 
"outputs": [ { - "cell_type": "markdown", - "source": [ - "### SpannerGraphStore\n", - "\n", - "To initialize the `SpannerGraphStore` class you need to provide 3 required arguments and other arguments are optional and only need to pass if it's different from default ones\n", - "\n", - "1. a Spanner instance id;\n", - "2. a Spanner database id belongs to the above instance id;\n", - "3. a Spanner graph name used to create a graph in the above database." - ], - "metadata": { - "id": "kpAv-tpcO_iL" - } + "output_type": "stream", + "name": "stdout", + "text": [ + "Waiting for DDL operations to complete...\n", + "Insert nodes of type `College`...\n", + "Insert nodes of type `Deparatment`...\n", + "Insert nodes of type `Year`...\n", + "Insert nodes of type `Person`...\n", + "Insert edges of type `Person_ALUMNIOF_College`...\n", + "Insert edges of type `Person_STUDIEDINDEPARTMENT_Deparatment`...\n", + "Insert edges of type `Person_GRADUATEDINYEAR_Year`...\n", + "No schema change required...\n", + "Insert nodes of type `College`...\n", + "Insert nodes of type `Deparatment`...\n", + "Insert nodes of type `Person`...\n", + "Insert edges of type `Person_ALUMNIOF_College`...\n", + "Insert edges of type `Person_STUDIEDINDEPARTMENT_Deparatment`...\n", + "Waiting for DDL operations to complete...\n", + "Insert nodes of type `College`...\n", + "Insert nodes of type `Deparatment`...\n", + "Insert nodes of type `Year`...\n", + "Insert nodes of type `Person`...\n", + "Insert edges of type `Person_ALUMNIOF_College`...\n", + "Insert edges of type `Deparatment_PARTOF_College`...\n", + "Insert edges of type `Person_STUDIEDINDEPARTMENT_Deparatment`...\n", + "Insert edges of type `College_GRADUATEDINYEAR_Year`...\n", + "Insert edges of type `Person_GRADUATEDINYEAR_Year`...\n", + "Insert edges of type `College_PARTOF_Deparatment`...\n", + "No schema change required...\n", + "Insert nodes of type `Person`...\n", + "Waiting for DDL operations to complete...\n", + "Insert nodes of type 
`Person`...\n", + "Insert nodes of type `Company`...\n", + "Insert nodes of type `Year`...\n", + "Insert edges of type `Person_FOUNDED_Company`...\n", + "Insert edges of type `Person_FOUNDED_Year`...\n", + "No schema change required...\n", + "Insert nodes of type `Person`...\n", + "Insert nodes of type `Company`...\n", + "Insert edges of type `Person_FOUNDED_Company`...\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Initialize the Spanner Graph QA Chain\n", + "The Spanner Graph QA Chain takes two parameters, a SpannerGraphStore object and a language model." + ], + "metadata": { + "id": "qlKwtdGN7kaT" + } + }, + { + "cell_type": "code", + "source": [ + "from google.cloud import spanner\n", + "from langchain_google_vertexai import ChatVertexAI\n", + "from IPython.core.display import HTML\n", + "\n", + "# Initialize llm object\n", + "llm = ChatVertexAI(model=\"gemini-1.5-flash-002\", temperature=0)\n", + "\n", + "# Initialize GraphQAChain\n", + "chain = SpannerGraphQAChain.from_llm(\n", + " llm,\n", + " graph=graph_store,\n", + " allow_dangerous_requests=True,\n", + " verbose=True,\n", + " return_intermediate_steps=True,\n", + ")" + ], + "metadata": { + "id": "7yKDAD9s7t7O" + }, + "execution_count": 30, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title Run Spanner Graph QA Chain 1\n", + "question = \"Who are the alumni of the college id Veritas University ?\" # @param {type:\"string\"}\n", + "response = chain.invoke(\"query=\" + question)\n", + "response[\"result\"]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 264 }, + "id": "ukKi9wtH_bF1", + "outputId": "61b66dcb-54cf-4620-a097-b4f0d732d1e3" + }, + "execution_count": 33, + "outputs": [ { - "cell_type": "code", - "source": [ - "from langchain_google_spanner import SpannerGraphStore\n", - "\n", - "graph_store = SpannerGraphStore(\n", - " instance_id=INSTANCE,\n", - " database_id=DATABASE,\n", - " graph_name=GRAPH_NAME,\n", - ")" 
- ], - "metadata": { - "id": "u589YapWQFb8" - }, - "execution_count": 16, - "outputs": [] + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new SpannerGraphQAChain chain...\u001b[0m\n", + "Executing gql:\n", + "\u001b[32;1m\u001b[1;3mGRAPH graph_demo_2\n", + "MATCH (p:Person)-[:ALUMNIOF]->(c:College {id: \"Veritas University\"})\n", + "RETURN p.id AS person_id, c.id AS college_id\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[{'person_id': 'Aisha Khan', 'college_id': 'Veritas University'}, {'person_id': 'Emily Davis', 'college_id': 'Veritas University'}, {'person_id': 'James Rodriguez', 'college_id': 'Veritas University'}, {'person_id': 'Michael Patel', 'college_id': 'Veritas University'}, {'person_id': 'Sarah Chen', 'college_id': 'Veritas University'}]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] }, { - "cell_type": "markdown", - "source": [ - "#### Add Graph Documents to Spanner Graph" + "output_type": "execute_result", + "data": { + "text/plain": [ + "'Aisha Khan, Emily Davis, James Rodriguez, Michael Patel, and Sarah Chen are alumni of Veritas University.\\n'" ], - "metadata": { - "id": "G7-Pe2ADQlNJ" + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" } + }, + "metadata": {}, + "execution_count": 33 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Run Spanner Graph QA Chain 2\n", + "question = \"List the companies, their founders and the college they attended.\" # @param {type:\"string\"}\n", + "response = chain.invoke(\"query=\" + question)\n", + "response[\"result\"]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 }, + "outputId": "e47d9f63-6769-49bc-b3a3-412c10de5c8a", + "id": "lcBc4tG__7Rm" + }, + "execution_count": 34, + "outputs": [ { - "cell_type": "code", - "source": [ - " # @title Extract Nodes and Edges from text snippets\n", - "from langchain_core.documents import Document\n", - 
"from langchain_experimental.graph_transformers import LLMGraphTransformer\n", - "from langchain_google_vertexai import ChatVertexAI\n", - "from langchain_text_splitters import RecursiveCharacterTextSplitter\n", - "\n", - "text_snippets = [\n", - "# Text snippet for students graduting from Veritas University, Computer Science Dept 2017\n", - "\n", - "\"\"\"\n", - "This was the graduation ceremony of 2017. A wave of jubilant graduates poured out of the\n", - "grand halls of Veritas University, their laughter echoing across the quad. Among them were\n", - "a cohort of exceptional students from the Computer Science department, a group that had\n", - "become known for their collaborative spirit and innovative ideas.\n", - "Leading the pack was Emily Davis, a coding whiz with a passion for cybersecurity, already\n", - "fielding offers from top tech firms. Beside her walked James Rodriguez, a quiet but\n", - "brilliant mind fascinated by artificial intelligence, dreaming of building machines that\n", - "could understand human emotions. Trailing slightly behind, deep in conversation, were\n", - "Sarah Chen and Michael Patel, both aspiring game developers, eager to bring their creative\n", - "visions to life. And then there was Aisha Khan, a social justice advocate who planned to\n", - "use her coding skills to address inequality through technology.\n", - "As they celebrated their achievements, these Veritas University Computer Science graduates\n", - "were ready to embark on diverse paths, each carrying the potential to shape the future of\n", - "technology in their own unique way.\n", - "\"\"\",\n", - "\n", - "# Text snippet for students graduting from Oakhaven University, Computer Science Dept 2016\n", - "\"\"\"\n", - "The year was 2016, and a palpable buzz filled the air as the graduating class of Oakhaven\n", - "university from Computer science and Engineering department emerged from the Beckman\n", - "Auditorium. 
Among them was a group of exceptional students, renowned for their\n", - "intellectual curiosity and groundbreaking research.\n", - "At the forefront was Alice Johnson, a gifted programmer with a fascination for quantum\n", - "computing, already collaborating with leading researchers in the field. Beside her\n", - "strode David Kim, a brilliant theorist captivated by the intricacies of cryptography,\n", - "eager to contribute to the development of secure communication systems. Engaged in an\n", - "animated discussion were Maria Rodriguez and Robert Lee, both passionate about robotics\n", - "and determined to push the boundaries of artificial intelligence. And then there was\n", - "Chloe Brown, a visionary with a deep interest in bioinformatics, driven to unlock the\n", - "secrets of the human genome through computational analysis.\n", - "As they celebrated their accomplishments, these graduates, armed with their exceptional\n", - "skills and unwavering determination, were poised to make significant contributions to the world of computing and beyond.\n", - "\"\"\",\n", - "\n", - "# Text snippet mentions the company Emily Davis founded.\n", - "# The snippet doesn't mention that she is an alumni of Veritas University\n", - "\"\"\"\n", - "Emily Davis, a name synonymous with cybersecurity innovation, turned that passion into a\n", - "thriving business. In the year 2022, Davis founded Ironclad Security, a company that's\n", - "rapidly changing the landscape of cybersecurity solutions.\n", - "\"\"\",\n", - "\n", - "# Text snippet mentions the company Alice Johnson founded.\n", - "# The snippet doesn't mention that she is an alumni of Oakhaven University.\n", - "\"\"\"\n", - "Alice Johnson had a vision that extended far beyond the classroom. 
Driven by an insatiable\n", - "curiosity about the potential of quantum mechanics, she founded Entangled Solutions, a\n", - "company poised to revolutionize industries through the power of quantum technology.\n", - "Entangled Solutions distinguishes itself by focusing on practical applications of quantum\n", - "computing.\n", - "\"\"\"\n", - "]\n", - "\n", - "# Create splits for documents\n", - "documents = [Document(page_content=t) for t in text_snippets]\n", - "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n", - "splits = text_splitter.split_documents(documents)\n", - "llm = ChatVertexAI(model=\"gemini-1.5-flash\", temperature=0)\n", - "llm_transformer = LLMGraphTransformer(\n", - " llm=llm,\n", - " allowed_nodes = [\"College\", \"Deparatment\", \"Person\", \"Year\", \"Company\"],\n", - " allowed_relationships = [\"AlumniOf\", \"StudiedInDepartment\", \"PartOf\", \"GraduatedInYear\", \"Founded\"],\n", - " node_properties=[ \"description\", ],\n", - ")\n", - "graph_documents = llm_transformer.convert_to_graph_documents(splits)\n" - ], - "metadata": { - "id": "fP7XNu3aPl5c" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "# @title Print extracted nodes and edges\n", - "for doc in graph_documents:\n", - " print(doc.source.page_content[:100])\n", - " print(doc.nodes)\n", - " print(doc.relationships)\n", - " print()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "OylyNyv-ZsT2", - "outputId": "e4253d98-ad63-4ea8-a5f1-0e3dac8f6632" - }, - "execution_count": 13, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "This was the graduation ceremony of 2017. 
A wave of jubilant graduates poured out of the\n", - "grand halls\n", - "[Node(id='Veritas University', type='College', properties={'description': 'grand halls'}), Node(id='Computer Science', type='Deparatment', properties={}), Node(id='2017', type='Year', properties={}), Node(id='Emily Davis', type='Person', properties={'description': 'coding whiz with a passion for cybersecurity'}), Node(id='James Rodriguez', type='Person', properties={'description': 'quiet but brilliant mind fascinated by artificial intelligence'}), Node(id='Sarah Chen', type='Person', properties={'description': 'aspiring game developers'}), Node(id='Michael Patel', type='Person', properties={'description': 'aspiring game developers'}), Node(id='Aisha Khan', type='Person', properties={'description': 'social justice advocate'})]\n", - "[Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='James Rodriguez', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Sarah Chen', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Michael Patel', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='James Rodriguez', type='Person', properties={}), target=Node(id='Computer 
Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Sarah Chen', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Michael Patel', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='James Rodriguez', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Sarah Chen', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Michael Patel', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='2017', type='Year', properties={}), type='GRADUATEDINYEAR', properties={})]\n", - "\n", - "visions to life. 
And then there was Aisha Khan, a social justice advocate who planned to\n", - "use her c\n", - "[Node(id='Veritas University', type='College', properties={}), Node(id='Computer Science', type='Deparatment', properties={}), Node(id='Aisha Khan', type='Person', properties={'description': 'social justice advocate'})]\n", - "[Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='Veritas University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Aisha Khan', type='Person', properties={}), target=Node(id='Computer Science', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={})]\n", - "\n", - "The year was 2016, and a palpable buzz filled the air as the graduating class of Oakhaven\n", - "university\n", - "[Node(id='Oakhaven University', type='College', properties={'description': 'Oakhaven university'}), Node(id='Computer Science And Engineering', type='Deparatment', properties={'description': 'Computer science and Engineering'}), Node(id='2016', type='Year', properties={'description': '2016'}), Node(id='Alice Johnson', type='Person', properties={'description': 'a gifted programmer with a fascination for quantum computing, already collaborating with leading researchers in the field'}), Node(id='David Kim', type='Person', properties={'description': 'a brilliant theorist captivated by the intricacies of cryptography, eager to contribute to the development of secure communication systems'}), Node(id='Maria Rodriguez', type='Person', properties={'description': 'passionate about robotics and determined to push the boundaries of artificial intelligence'}), Node(id='Robert Lee', type='Person', properties={'description': 'passionate about robotics and determined to push the boundaries of artificial intelligence'}), Node(id='Chloe Brown', type='Person', properties={'description': 'a visionary with a deep interest in bioinformatics, driven to unlock the secrets of the human 
genome through computational analysis'}), Node(id='Beckman Auditorium', type='Deparatment', properties={'description': 'Beckman Auditorium'})]\n", - "[Relationship(source=Node(id='Alice Johnson', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='David Kim', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Maria Rodriguez', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Robert Lee', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Chloe Brown', type='Person', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='ALUMNIOF', properties={}), Relationship(source=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), target=Node(id='Oakhaven University', type='College', properties={}), type='PARTOF', properties={}), Relationship(source=Node(id='Alice Johnson', type='Person', properties={}), target=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='David Kim', type='Person', properties={}), target=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Maria Rodriguez', type='Person', properties={}), target=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Robert Lee', type='Person', properties={}), target=Node(id='Computer Science And Engineering', 
type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Chloe Brown', type='Person', properties={}), target=Node(id='Computer Science And Engineering', type='Deparatment', properties={}), type='STUDIEDINDEPARTMENT', properties={}), Relationship(source=Node(id='Oakhaven University', type='College', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Alice Johnson', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='David Kim', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Maria Rodriguez', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Robert Lee', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Chloe Brown', type='Person', properties={}), target=Node(id='2016', type='Year', properties={}), type='GRADUATEDINYEAR', properties={}), Relationship(source=Node(id='Oakhaven University', type='College', properties={}), target=Node(id='Beckman Auditorium', type='Deparatment', properties={}), type='PARTOF', properties={})]\n", - "\n", - "Chloe Brown, a visionary with a deep interest in bioinformatics, driven to unlock the\n", - "secrets of the\n", - "[Node(id='Chloe Brown', type='Person', properties={'description': 'a visionary with a deep interest in bioinformatics, driven to unlock the secrets of the human genome through computational analysis'})]\n", - "[]\n", - "\n", - "Emily Davis, a name synonymous with cybersecurity innovation, turned that passion into a\n", - "thriving bu\n", - "[Node(id='Emily Davis', type='Person', 
properties={'description': 'a name synonymous with cybersecurity innovation'}), Node(id='Ironclad Security', type='Company', properties={'description': \"a company that's rapidly changing the landscape of cybersecurity solutions\"}), Node(id='2022', type='Year', properties={})]\n", - "[Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='Ironclad Security', type='Company', properties={}), type='FOUNDED', properties={}), Relationship(source=Node(id='Emily Davis', type='Person', properties={}), target=Node(id='2022', type='Year', properties={}), type='FOUNDED', properties={})]\n", - "\n", - "Alice Johnson had a vision that extended far beyond the classroom. Driven by an insatiable\n", - "curiosity\n", - "[Node(id='Alice Johnson', type='Person', properties={'description': 'Driven by an insatiable curiosity about the potential of quantum mechanics, she founded Entangled Solutions, a company poised to revolutionize industries through the power of quantum technology.'}), Node(id='Entangled Solutions', type='Company', properties={'description': 'Entangled Solutions distinguishes itself by focusing on practical applications of quantum computing.'})]\n", - "[Relationship(source=Node(id='Alice Johnson', type='Person', properties={}), target=Node(id='Entangled Solutions', type='Company', properties={}), type='FOUNDED', properties={})]\n", - "\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "# @title Load the graph to Spanner Graph database\n", - "# Uncomment the line below, if you want to cleanup from\n", - "# previous iterations.\n", - "# BeWARE - THIS COULD REMOVE DATA FROM YOUR DATABASE !!!\n", - "# graph_store.cleanup()\n", - "\n", - "\n", - "for graph_document in graph_documents:\n", - " graph_store.add_graph_documents([graph_document])" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lMXvOpRbZdau", - "outputId": "26647456-2316-46e3-de43-cfc9845a1050" - }, - "execution_count": 18, 
- "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Waiting for DDL operations to complete...\n", - "Insert nodes of type `College`...\n", - "Insert nodes of type `Deparatment`...\n", - "Insert nodes of type `Year`...\n", - "Insert nodes of type `Person`...\n", - "Insert edges of type `Person_ALUMNIOF_College`...\n", - "Insert edges of type `Person_STUDIEDINDEPARTMENT_Deparatment`...\n", - "Insert edges of type `Person_GRADUATEDINYEAR_Year`...\n", - "No schema change required...\n", - "Insert nodes of type `College`...\n", - "Insert nodes of type `Deparatment`...\n", - "Insert nodes of type `Person`...\n", - "Insert edges of type `Person_ALUMNIOF_College`...\n", - "Insert edges of type `Person_STUDIEDINDEPARTMENT_Deparatment`...\n", - "Waiting for DDL operations to complete...\n", - "Insert nodes of type `College`...\n", - "Insert nodes of type `Deparatment`...\n", - "Insert nodes of type `Year`...\n", - "Insert nodes of type `Person`...\n", - "Insert edges of type `Person_ALUMNIOF_College`...\n", - "Insert edges of type `Deparatment_PARTOF_College`...\n", - "Insert edges of type `Person_STUDIEDINDEPARTMENT_Deparatment`...\n", - "Insert edges of type `College_GRADUATEDINYEAR_Year`...\n", - "Insert edges of type `Person_GRADUATEDINYEAR_Year`...\n", - "Insert edges of type `College_PARTOF_Deparatment`...\n", - "No schema change required...\n", - "Insert nodes of type `Person`...\n", - "Waiting for DDL operations to complete...\n", - "Insert nodes of type `Person`...\n", - "Insert nodes of type `Company`...\n", - "Insert nodes of type `Year`...\n", - "Insert edges of type `Person_FOUNDED_Company`...\n", - "Insert edges of type `Person_FOUNDED_Year`...\n", - "No schema change required...\n", - "Insert nodes of type `Person`...\n", - "Insert nodes of type `Company`...\n", - "Insert edges of type `Person_FOUNDED_Company`...\n" - ] - } - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new 
SpannerGraphQAChain chain...\u001b[0m\n", + "Executing gql:\n", + "\u001b[32;1m\u001b[1;3mGRAPH graph_demo_2\n", + "MATCH (p:Person)-[:FOUNDED]->(c:Company), (p)-[:ALUMNIOF]->(cl:College)\n", + "RETURN c.id AS company_id, c.description AS company_description, p.id AS founder_id, p.description AS founder_description, cl.id AS college_id, cl.description AS college_description\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[{'company_id': 'Entangled Solutions', 'company_description': 'Entangled Solutions distinguishes itself by focusing on practical applications of quantum computing.', 'founder_id': 'Alice Johnson', 'founder_description': 'Driven by an insatiable curiosity about the potential of quantum mechanics, she founded Entangled Solutions, a company poised to revolutionize industries through the power of quantum technology.', 'college_id': 'Oakhaven University', 'college_description': 'Oakhaven university'}, {'company_id': 'Ironclad Security', 'company_description': \"a company that's rapidly changing the landscape of cybersecurity solutions\", 'founder_id': 'Emily Davis', 'founder_description': 'a name synonymous with cybersecurity innovation', 'college_id': 'Veritas University', 'college_description': 'grand halls'}]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] }, { - "cell_type": "markdown", - "source": [ - "### Initialize the Spanner Graph QA Chain\n", - "The Spanner Graph QA Chain takes two parameters, a SpannerGraphStore object and a language model." + "output_type": "execute_result", + "data": { + "text/plain": [ + "'Entangled Solutions, founded by Alice Johnson who attended Oakhaven University, focuses on practical applications of quantum computing. 
Ironclad Security, founded by Emily Davis who attended Veritas University, is rapidly changing the landscape of cybersecurity solutions.\\n'" ], - "metadata": { - "id": "qlKwtdGN7kaT" + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" } + }, + "metadata": {}, + "execution_count": 34 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Run Spanner Graph QA Chain 3\n", + "question = \"Which companies were founded by alumni of college id Veritas University ? Who were the founders ?\" # @param {type:\"string\"}\n", + "response = chain.invoke(\"query=\" + question)\n", + "response[\"result\"]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 264 }, + "outputId": "cb40179e-bcec-4399-df9d-a114e02b33f9", + "id": "e6djmq1NAGOM" + }, + "execution_count": 35, + "outputs": [ { - "cell_type": "code", - "source": [ - "from google.cloud import spanner\n", - "from langchain_google_vertexai import ChatVertexAI\n", - "from IPython.core.display import HTML\n", - "\n", - "# Initialize llm object\n", - "llm = ChatVertexAI(model=\"gemini-1.5-flash-002\", temperature=0)\n", - "\n", - "# Initialize GraphQAChain\n", - "chain = SpannerGraphQAChain.from_llm(\n", - " llm,\n", - " graph=graph_store,\n", - " allow_dangerous_requests=True,\n", - " verbose=True,\n", - " return_intermediate_steps=True\n", - ")" - ], - "metadata": { - "id": "7yKDAD9s7t7O" - }, - "execution_count": 30, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "# @title Run Spanner Graph QA Chain 1\n", - "question = \"Who are the alumni of the college id Veritas University ?\" # @param {type:\"string\"}\n", - "response = chain.invoke(\"query=\" + question)\n", - "response[\"result\"]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 264 - }, - "id": "ukKi9wtH_bF1", - "outputId": "61b66dcb-54cf-4620-a097-b4f0d732d1e3" - }, - "execution_count": 33, - "outputs": [ - { - "output_type": "stream", - 
"name": "stdout", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new SpannerGraphQAChain chain...\u001b[0m\n", - "Executing gql:\n", - "\u001b[32;1m\u001b[1;3mGRAPH graph_demo_2\n", - "MATCH (p:Person)-[:ALUMNIOF]->(c:College {id: \"Veritas University\"})\n", - "RETURN p.id AS person_id, c.id AS college_id\u001b[0m\n", - "Full Context:\n", - "\u001b[32;1m\u001b[1;3m[{'person_id': 'Aisha Khan', 'college_id': 'Veritas University'}, {'person_id': 'Emily Davis', 'college_id': 'Veritas University'}, {'person_id': 'James Rodriguez', 'college_id': 'Veritas University'}, {'person_id': 'Michael Patel', 'college_id': 'Veritas University'}, {'person_id': 'Sarah Chen', 'college_id': 'Veritas University'}]\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "'Aisha Khan, Emily Davis, James Rodriguez, Michael Patel, and Sarah Chen are alumni of Veritas University.\\n'" - ], - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - } - }, - "metadata": {}, - "execution_count": 33 - } - ] - }, - { - "cell_type": "code", - "source": [ - "# @title Run Spanner Graph QA Chain 2\n", - "question = \"List the companies, their founders and the college they attended.\" # @param {type:\"string\"}\n", - "response = chain.invoke(\"query=\" + question)\n", - "response[\"result\"]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 281 - }, - "outputId": "e47d9f63-6769-49bc-b3a3-412c10de5c8a", - "id": "lcBc4tG__7Rm" - }, - "execution_count": 34, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new SpannerGraphQAChain chain...\u001b[0m\n", - "Executing gql:\n", - "\u001b[32;1m\u001b[1;3mGRAPH graph_demo_2\n", - "MATCH (p:Person)-[:FOUNDED]->(c:Company), (p)-[:ALUMNIOF]->(cl:College)\n", - "RETURN c.id AS company_id, c.description AS company_description, p.id AS founder_id, 
p.description AS founder_description, cl.id AS college_id, cl.description AS college_description\u001b[0m\n", - "Full Context:\n", - "\u001b[32;1m\u001b[1;3m[{'company_id': 'Entangled Solutions', 'company_description': 'Entangled Solutions distinguishes itself by focusing on practical applications of quantum computing.', 'founder_id': 'Alice Johnson', 'founder_description': 'Driven by an insatiable curiosity about the potential of quantum mechanics, she founded Entangled Solutions, a company poised to revolutionize industries through the power of quantum technology.', 'college_id': 'Oakhaven University', 'college_description': 'Oakhaven university'}, {'company_id': 'Ironclad Security', 'company_description': \"a company that's rapidly changing the landscape of cybersecurity solutions\", 'founder_id': 'Emily Davis', 'founder_description': 'a name synonymous with cybersecurity innovation', 'college_id': 'Veritas University', 'college_description': 'grand halls'}]\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "'Entangled Solutions, founded by Alice Johnson who attended Oakhaven University, focuses on practical applications of quantum computing. Ironclad Security, founded by Emily Davis who attended Veritas University, is rapidly changing the landscape of cybersecurity solutions.\\n'" - ], - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - } - }, - "metadata": {}, - "execution_count": 34 - } - ] - }, - { - "cell_type": "code", - "source": [ - "# @title Run Spanner Graph QA Chain 3\n", - "question = \"Which companies were founded by alumni of college id Veritas University ? 
Who were the founders ?\" # @param {type:\"string\"}\n", - "response = chain.invoke(\"query=\" + question)\n", - "response[\"result\"]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 264 - }, - "outputId": "cb40179e-bcec-4399-df9d-a114e02b33f9", - "id": "e6djmq1NAGOM" - }, - "execution_count": 35, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new SpannerGraphQAChain chain...\u001b[0m\n", - "Executing gql:\n", - "\u001b[32;1m\u001b[1;3mGRAPH graph_demo_2\n", - "MATCH (c:College {id: \"Veritas University\"})<-[:ALUMNIOF]-(p:Person)-[:FOUNDED]->(co:Company)\n", - "RETURN co.id AS company_id, co.description AS company_description, p.id AS founder_id, p.description AS founder_description\u001b[0m\n", - "Full Context:\n", - "\u001b[32;1m\u001b[1;3m[{'company_id': 'Ironclad Security', 'company_description': \"a company that's rapidly changing the landscape of cybersecurity solutions\", 'founder_id': 'Emily Davis', 'founder_description': 'a name synonymous with cybersecurity innovation'}]\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "\"Ironclad Security, a company that's rapidly changing the landscape of cybersecurity solutions, was founded by Emily Davis, a name synonymous with cybersecurity innovation.\\n\"" - ], - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - } - }, - "metadata": {}, - "execution_count": 35 - } - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new SpannerGraphQAChain chain...\u001b[0m\n", + "Executing gql:\n", + "\u001b[32;1m\u001b[1;3mGRAPH graph_demo_2\n", + "MATCH (c:College {id: \"Veritas University\"})<-[:ALUMNIOF]-(p:Person)-[:FOUNDED]->(co:Company)\n", + "RETURN co.id AS company_id, co.description AS company_description, p.id AS founder_id, p.description AS 
founder_description\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[{'company_id': 'Ironclad Security', 'company_description': \"a company that's rapidly changing the landscape of cybersecurity solutions\", 'founder_id': 'Emily Davis', 'founder_description': 'a name synonymous with cybersecurity innovation'}]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] }, { - "cell_type": "markdown", - "source": [ - "#### Clean up the graph\n", - "\n", - "> USE IT WITH CAUTION!\n", - "\n", - "Clean up all the nodes/edges in your graph and remove your graph definition." + "output_type": "execute_result", + "data": { + "text/plain": [ + "\"Ironclad Security, a company that's rapidly changing the landscape of cybersecurity solutions, was founded by Emily Davis, a name synonymous with cybersecurity innovation.\\n\"" ], - "metadata": { - "id": "pM7TmfI0TEFy" + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" } - }, - { - "cell_type": "code", - "source": [ - "graph_store.cleanup()" - ], - "metadata": { - "id": "UQWq4-sITOgl" - }, - "execution_count": null, - "outputs": [] + }, + "metadata": {}, + "execution_count": 35 } - ] + ] + }, + { + "cell_type": "markdown", + "source": [ + "#### Clean up the graph\n", + "\n", + "> USE IT WITH CAUTION!\n", + "\n", + "Clean up all the nodes/edges in your graph and remove your graph definition." 
+ ], + "metadata": { + "id": "pM7TmfI0TEFy" + } + }, + { + "cell_type": "code", + "source": [ + "graph_store.cleanup()" + ], + "metadata": { + "id": "UQWq4-sITOgl" + }, + "execution_count": null, + "outputs": [] + } + ] } \ No newline at end of file diff --git a/src/langchain_google_spanner/prompts.py b/src/langchain_google_spanner/prompts.py index 86202b4..c118e33 100644 --- a/src/langchain_google_spanner/prompts.py +++ b/src/langchain_google_spanner/prompts.py @@ -1,3 +1,17 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ GQL_EXAMPLES = """ The following query in backtick matches all persons in the graph FinGraph whose birthday is before 1990-01-10 and diff --git a/tests/integration/test_spanner_graph_qa.py b/tests/integration/test_spanner_graph_qa.py index bf6551f..1ac7e3e 100644 --- a/tests/integration/test_spanner_graph_qa.py +++ b/tests/integration/test_spanner_graph_qa.py @@ -28,7 +28,6 @@ project_id = os.environ["PROJECT_ID"] instance_id = os.environ["INSTANCE_ID"] database_id = os.environ["GOOGLE_DATABASE"] -graph_name = os.environ["GRAPH_NAME"] def get_llm(): @@ -47,6 +46,8 @@ def get_evaluator(): def get_spanner_graph(): + suffix = random_string(num_char=5, exclude_whitespaces=True) + graph_name = "test_graph{}".format(suffix) graph = SpannerGraphStore( instance_id=instance_id, database_id=database_id, @@ -133,6 +134,7 @@ def setup_db_load_data(self): load_data(graph) yield graph # teardown + print(graph.get_schema) graph.cleanup() @pytest.fixture From 7a27522fbadf2565683ff133093b35a2ef3da337 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Wed, 27 Nov 2024 23:37:07 +0000 Subject: [PATCH 03/40] Add missing imports for random graph name --- tests/integration/test_spanner_graph_qa.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/integration/test_spanner_graph_qa.py b/tests/integration/test_spanner_graph_qa.py index 1ac7e3e..3d4f155 100644 --- a/tests/integration/test_spanner_graph_qa.py +++ b/tests/integration/test_spanner_graph_qa.py @@ -13,6 +13,8 @@ # limitations under the License. 
import os +import random +import string import pytest from google.cloud import spanner @@ -30,6 +32,15 @@ database_id = os.environ["GOOGLE_DATABASE"] +def random_string(num_char=5, exclude_whitespaces=False): + return "".join( + random.choice( + string.ascii_letters + ("" if exclude_whitespaces else string.whitespace) + ) + for _ in range(num_char) + ) + + def get_llm(): llm = ChatVertexAI( model="gemini-1.5-flash-002", From e0bbdf1d72873dce0a7332f112038f5a44f586ef Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Thu, 28 Nov 2024 00:48:37 +0000 Subject: [PATCH 04/40] Make input table name randomized in integration tests to avoid name collision for tests running parallely from different python environments --- tests/integration/test_spanner_graph_qa.py | 43 ++++++++++++++-------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/tests/integration/test_spanner_graph_qa.py b/tests/integration/test_spanner_graph_qa.py index 3d4f155..c7fed55 100644 --- a/tests/integration/test_spanner_graph_qa.py +++ b/tests/integration/test_spanner_graph_qa.py @@ -32,7 +32,7 @@ database_id = os.environ["GOOGLE_DATABASE"] -def random_string(num_char=5, exclude_whitespaces=False): +def random_string(num_char=3, exclude_whitespaces=False): return "".join( random.choice( string.ascii_letters + ("" if exclude_whitespaces else string.whitespace) @@ -57,7 +57,7 @@ def get_evaluator(): def get_spanner_graph(): - suffix = random_string(num_char=5, exclude_whitespaces=True) + suffix = random_string(num_char=3, exclude_whitespaces=True) graph_name = "test_graph{}".format(suffix) graph = SpannerGraphStore( instance_id=instance_id, @@ -69,12 +69,13 @@ def get_spanner_graph(): def load_data(graph: SpannerGraphStore): + type_suffix = "_" + random_string(num_char=3, exclude_whitespaces=True) graph_documents = [ GraphDocument( nodes=[ Node( id="Elias Thorne", - type="Person", + type="Person" + type_suffix, properties={ "name": "Elias Thorne", "description": "lived in the desert", @@ 
-82,42 +83,54 @@ def load_data(graph: SpannerGraphStore): ), Node( id="Zephyr", - type="Animal", + type="Animal" + type_suffix, properties={"name": "Zephyr", "description": "pet falcon"}, ), Node( id="Elara", - type="Person", + type="Person" + type_suffix, properties={ "name": "Elara", "description": "resided in the capital city", }, ), - Node(id="Desert", type="Location", properties={}), - Node(id="Capital City", type="Location", properties={}), + Node(id="Desert", type="Location" + type_suffix, properties={}), + Node(id="Capital City", type="Location" + type_suffix, properties={}), ], relationships=[ Relationship( - source=Node(id="Elias Thorne", type="Person", properties={}), - target=Node(id="Desert", type="Location", properties={}), + source=Node( + id="Elias Thorne", type="Person" + type_suffix, properties={} + ), + target=Node( + id="Desert", type="Location" + type_suffix, properties={} + ), type="LivesIn", properties={}, ), Relationship( - source=Node(id="Elias Thorne", type="Person", properties={}), - target=Node(id="Zephyr", type="Animal", properties={}), + source=Node( + id="Elias Thorne", type="Person" + type_suffix, properties={} + ), + target=Node( + id="Zephyr", type="Animal" + type_suffix, properties={} + ), type="Owns", properties={}, ), Relationship( - source=Node(id="Elara", type="Person", properties={}), - target=Node(id="Capital City", type="Location", properties={}), + source=Node(id="Elara", type="Person" + type_suffix, properties={}), + target=Node( + id="Capital City", type="Location" + type_suffix, properties={} + ), type="LivesIn", properties={}, ), Relationship( - source=Node(id="Elias Thorne", type="Person", properties={}), - target=Node(id="Elara", type="Person", properties={}), + source=Node( + id="Elias Thorne", type="Person" + type_suffix, properties={} + ), + target=Node(id="Elara", type="Person" + type_suffix, properties={}), type="Sibling", properties={}, ), From 7f578bcf7a3ed60ff4177d895f23703490d8195e Mon Sep 17 00:00:00 2001 
From: Amarnath Mullick Date: Mon, 2 Dec 2024 23:36:14 +0000 Subject: [PATCH 05/40] Provide timeout to graph cleanup --- src/langchain_google_spanner/graph_store.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index 20ec746..adbdac3 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -1025,7 +1025,7 @@ def _add_edges( columns.append(ElementSchema.TARGET_NODE_KEY_COLUMN_NAME) return name, columns, rows - def cleanup(self): + def cleanup(self, timeout: int = 60): """Removes all data from your Spanner Graph. USE IT WITH CAUTION! @@ -1038,18 +1038,21 @@ def cleanup(self): "DROP PROPERTY GRAPH IF EXISTS {}".format( to_identifier(self.schema.graph_name) ) - ] + ], + {timeout: 300}, ) self.impl.apply_ddls( [ "DROP TABLE IF EXISTS {}".format(to_identifier(edge.base_table_name)) for edge in self.schema.edges.values() - ] + ], + {timeout: 300}, ) self.impl.apply_ddls( [ "DROP TABLE IF EXISTS {}".format(to_identifier(node.base_table_name)) for node in self.schema.nodes.values() - ] + ], + {timeout: 300}, ) self.schema = SpannerGraphSchema(self.schema.graph_name) From 3650d892571a2b305a09df0fc58b900fa47548db Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Tue, 3 Dec 2024 00:58:02 +0000 Subject: [PATCH 06/40] Make default timeout of 300 secs for ddl application --- src/langchain_google_spanner/graph_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index adbdac3..4f440b7 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -834,7 +834,7 @@ def apply_ddls(self, ddls: List[str], options: Dict[str, Any] = {}) -> None: op = self.database.update_ddl(ddl_statements=ddls) print("Waiting for DDL operations to complete...") - return 
op.result(options.get("timeout", 60)) + return op.result(options.get("timeout", 300)) def insert_or_update( self, table: str, columns: List[str], values: List[List[Any]] From 63b3508c284975c9acb323f17f08330a37291f30 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Tue, 3 Dec 2024 06:27:50 +0000 Subject: [PATCH 07/40] Increase timeout of integration test --- integration.cloudbuild.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index 5e96b0e..578ad75 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -33,7 +33,7 @@ steps: - "GOOGLE_DATABASE=${_GOOGLE_DATABASE}" - "PG_DATABASE=${_PG_DATABASE}" -timeout: "4800s" +timeout: "6000s" substitutions: _INSTANCE_ID: test-instance _GOOGLE_DATABASE: test-google-db From b9f718c8da39e7b749ada7bac96cce608df5846b Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Tue, 3 Dec 2024 06:28:48 +0000 Subject: [PATCH 08/40] Change integration test timeout --- integration.cloudbuild.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index 578ad75..538325c 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -33,7 +33,7 @@ steps: - "GOOGLE_DATABASE=${_GOOGLE_DATABASE}" - "PG_DATABASE=${_PG_DATABASE}" -timeout: "6000s" +timeout: "7200s" substitutions: _INSTANCE_ID: test-instance _GOOGLE_DATABASE: test-google-db From 95768fce0098d0cfb1cb7981d3bdb75b45ab565d Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Tue, 3 Dec 2024 21:55:56 +0000 Subject: [PATCH 09/40] Minor formatting fixes --- src/langchain_google_spanner/__init__.py | 2 +- tests/integration/test_spanner_graph_qa.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/langchain_google_spanner/__init__.py b/src/langchain_google_spanner/__init__.py index 5f5b5ae..28c2dd1 100644 --- a/src/langchain_google_spanner/__init__.py +++ 
b/src/langchain_google_spanner/__init__.py @@ -13,8 +13,8 @@ # limitations under the License. from langchain_google_spanner.chat_message_history import SpannerChatMessageHistory -from langchain_google_spanner.graph_store import SpannerGraphStore from langchain_google_spanner.graph_qa import SpannerGraphQAChain +from langchain_google_spanner.graph_store import SpannerGraphStore from langchain_google_spanner.vector_store import ( DistanceStrategy, QueryParameters, diff --git a/tests/integration/test_spanner_graph_qa.py b/tests/integration/test_spanner_graph_qa.py index c7fed55..68ee30f 100644 --- a/tests/integration/test_spanner_graph_qa.py +++ b/tests/integration/test_spanner_graph_qa.py @@ -26,7 +26,6 @@ from langchain_google_spanner.graph_qa import SpannerGraphQAChain from langchain_google_spanner.graph_store import SpannerGraphStore - project_id = os.environ["PROJECT_ID"] instance_id = os.environ["INSTANCE_ID"] database_id = os.environ["GOOGLE_DATABASE"] From 638cdeb383979f3d23ab6658273039b81d1dae85 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Tue, 3 Dec 2024 23:43:15 +0000 Subject: [PATCH 10/40] Make the ddl operations test fixture scoped for the module --- tests/integration/test_spanner_graph_qa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_spanner_graph_qa.py b/tests/integration/test_spanner_graph_qa.py index 68ee30f..a4b0980 100644 --- a/tests/integration/test_spanner_graph_qa.py +++ b/tests/integration/test_spanner_graph_qa.py @@ -151,7 +151,7 @@ def load_data(graph: SpannerGraphStore): class TestSpannerGraphQAChain: - @pytest.fixture + @pytest.fixture(scope="module") def setup_db_load_data(self): graph = get_spanner_graph() load_data(graph) From 88aa4c101e91737c3a00c379d88035a685e67fc2 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Wed, 4 Dec 2024 00:58:31 +0000 Subject: [PATCH 11/40] Addressed review comments --- requirements.txt | 1 - src/langchain_google_spanner/graph_qa.py | 3 +-- 
src/langchain_google_spanner/graph_store.py | 14 ++++++-------- src/langchain_google_spanner/prompts.py | 10 +++++----- tests/integration/test_spanner_graph_qa.py | 10 ++++------ 5 files changed, 16 insertions(+), 22 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3af70b2..3a65981 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ google-cloud-spanner==3.49.1 langchain-core==0.3.9 langchain-community==0.3.1 -langchain-experimental==0.3.2 langchain_google_vertexai pydantic==2.9.1 diff --git a/src/langchain_google_spanner/graph_qa.py b/src/langchain_google_spanner/graph_qa.py index 14d546f..f5fbf3d 100644 --- a/src/langchain_google_spanner/graph_qa.py +++ b/src/langchain_google_spanner/graph_qa.py @@ -193,8 +193,7 @@ def output_keys(self) -> List[str]: :meta private: """ - _output_keys = [self.output_key] - return _output_keys + return [self.output_key] @classmethod def from_llm( diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index 4f440b7..d3e03a0 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -28,6 +28,7 @@ from .type_utils import TypeUtility MUTATION_BATCH_SIZE = 1000 +DEFAULT_DDL_TIMEOUT = 300 class NodeWrapper(object): @@ -834,7 +835,7 @@ def apply_ddls(self, ddls: List[str], options: Dict[str, Any] = {}) -> None: op = self.database.update_ddl(ddl_statements=ddls) print("Waiting for DDL operations to complete...") - return op.result(options.get("timeout", 300)) + return op.result(options.get("timeout", DEFAULT_DDL_TIMEOUT)) def insert_or_update( self, table: str, columns: List[str], values: List[List[Any]] @@ -1025,7 +1026,7 @@ def _add_edges( columns.append(ElementSchema.TARGET_NODE_KEY_COLUMN_NAME) return name, columns, rows - def cleanup(self, timeout: int = 60): + def cleanup(self): """Removes all data from your Spanner Graph. USE IT WITH CAUTION! 
@@ -1038,21 +1039,18 @@ def cleanup(self, timeout: int = 60): "DROP PROPERTY GRAPH IF EXISTS {}".format( to_identifier(self.schema.graph_name) ) - ], - {timeout: 300}, + ] ) self.impl.apply_ddls( [ "DROP TABLE IF EXISTS {}".format(to_identifier(edge.base_table_name)) for edge in self.schema.edges.values() - ], - {timeout: 300}, + ] ) self.impl.apply_ddls( [ "DROP TABLE IF EXISTS {}".format(to_identifier(node.base_table_name)) for node in self.schema.nodes.values() - ], - {timeout: 300}, + ] ) self.schema = SpannerGraphSchema(self.schema.graph_name) diff --git a/src/langchain_google_spanner/prompts.py b/src/langchain_google_spanner/prompts.py index c118e33..29f8286 100644 --- a/src/langchain_google_spanner/prompts.py +++ b/src/langchain_google_spanner/prompts.py @@ -50,7 +50,7 @@ """ DEFAULT_GQL_TEMPLATE_PART0 = """ -Create an ISO GQL query for the question using the schema. +Create an Spanner Graph GQL query for the question using the schema. {gql_examples} """ @@ -148,7 +148,7 @@ """ DEFAULT_GQL_VERIFY_TEMPLATE_PART0 = """ -Given a natual language question, ISO GQL graph query and a graph schema, +Given a natual language question, Spanner Graph GQL graph query and a graph schema, validate the query. {verify_examples} @@ -177,9 +177,9 @@ ) DEFAULT_GQL_FIX_TEMPLATE_PART0 = """ -We generated a ISO GQL query to answer a natural language question. +We generated a Spanner Graph GQL query to answer a natural language question. Question: {question} -However the generated ISO GQL query is not valid. ``` +However the generated Spanner Graph GQL query is not valid. 
``` Input gql: {generated_gql} ``` The error obtained when executing the query is @@ -233,7 +233,7 @@ You are given the following information: - `Question`: the natural language question from the user - `Graph Schema`: contains the schema of the graph database -- `Graph Query`: A ISO GQL query equivalent of the question from the user used to extract context from the graph database +- `Graph Query`: A Spanner Graph GQL query equivalent of the question from the user used to extract context from the graph database - `Context`: The response from the graph database as context Information: Question: {question} diff --git a/tests/integration/test_spanner_graph_qa.py b/tests/integration/test_spanner_graph_qa.py index a4b0980..fe1dbca 100644 --- a/tests/integration/test_spanner_graph_qa.py +++ b/tests/integration/test_spanner_graph_qa.py @@ -31,11 +31,9 @@ database_id = os.environ["GOOGLE_DATABASE"] -def random_string(num_char=3, exclude_whitespaces=False): +def random_string(num_char=3): return "".join( - random.choice( - string.ascii_letters + ("" if exclude_whitespaces else string.whitespace) - ) + random.choice(string.ascii_letters) for _ in range(num_char) ) @@ -56,7 +54,7 @@ def get_evaluator(): def get_spanner_graph(): - suffix = random_string(num_char=3, exclude_whitespaces=True) + suffix = random_string(num_char=3) graph_name = "test_graph{}".format(suffix) graph = SpannerGraphStore( instance_id=instance_id, @@ -68,7 +66,7 @@ def get_spanner_graph(): def load_data(graph: SpannerGraphStore): - type_suffix = "_" + random_string(num_char=3, exclude_whitespaces=True) + type_suffix = "_" + random_string(num_char=3) graph_documents = [ GraphDocument( nodes=[ From a211728ffa28d3e2256f8a30d960223f9bbea04a Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Wed, 4 Dec 2024 06:15:10 +0000 Subject: [PATCH 12/40] Addressed a few other review comments. 
--- src/langchain_google_spanner/graph_qa.py | 34 ++++++++++++++++++++++ tests/integration/test_spanner_graph_qa.py | 5 +--- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/src/langchain_google_spanner/graph_qa.py b/src/langchain_google_spanner/graph_qa.py index f5fbf3d..012cf18 100644 --- a/src/langchain_google_spanner/graph_qa.py +++ b/src/langchain_google_spanner/graph_qa.py @@ -84,6 +84,8 @@ def fix_gql_syntax(query: str) -> str: query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]->", r"-[\1:\2]->{\3}", query) query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"<-[\1:\2]-{\3,\4}", query) query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\]-", r"<-[\1:\2]-{\3}", query) + query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"-[\1:\2]-{\3,\4}", query) + query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]-", r"-[\1:\2]-{\3}", query) return query @@ -290,6 +292,37 @@ def execute_with_retry( intermediate_steps: List, question: str, gql_query: str, + ) -> tuple[str, List[Any]]: + retries = 0 + while retries <= self.max_gql_fix_retries: + try: + intermediate_steps.append({"generated_query": gql_query}) + return gql_query, self.execute_query(_run_manager, gql_query) + except Exception as e: + err_msg = str(e) + self.log_invalid_query(_run_manager, gql_query, err_msg) + intermediate_steps.pop() + intermediate_steps.append({"query_failed_" + str(retries): gql_query}) + fix_chain_result = self.gql_fix_chain.invoke( + { + "question": question, + "err_msg": err_msg, + "generated_gql": gql_query, + "schema": self.graph.get_schema, + } + ) + gql_query = extract_gql(fix_chain_result) + finally: + retries += 1 + + raise ValueError("The generated gql query is invalid") + + def execute_with_retry_bkp( + self, + _run_manager: CallbackManagerForChainRun, + intermediate_steps: List, + question: str, + gql_query: str, ) -> tuple[str, List[Any]]: try: intermediate_steps.append({"generated_query": gql_query}) @@ -322,6 +355,7 @@ def execute_with_retry( gql_query = new_gql_query err_msg = str(e) 
self.log_invalid_query(_run_manager, gql_query, err_msg) + raise ValueError("The generated gql query is invalid") def log_invalid_query( diff --git a/tests/integration/test_spanner_graph_qa.py b/tests/integration/test_spanner_graph_qa.py index fe1dbca..8bac7b8 100644 --- a/tests/integration/test_spanner_graph_qa.py +++ b/tests/integration/test_spanner_graph_qa.py @@ -32,10 +32,7 @@ def random_string(num_char=3): - return "".join( - random.choice(string.ascii_letters) - for _ in range(num_char) - ) + return "".join(random.choice(string.ascii_letters) for _ in range(num_char)) def get_llm(): From 80338fc323e20580a467bf49870ce14430694d28 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Wed, 4 Dec 2024 06:22:37 +0000 Subject: [PATCH 13/40] Remove unused function --- src/langchain_google_spanner/graph_qa.py | 41 ------------------------ 1 file changed, 41 deletions(-) diff --git a/src/langchain_google_spanner/graph_qa.py b/src/langchain_google_spanner/graph_qa.py index 012cf18..bbf6861 100644 --- a/src/langchain_google_spanner/graph_qa.py +++ b/src/langchain_google_spanner/graph_qa.py @@ -317,47 +317,6 @@ def execute_with_retry( raise ValueError("The generated gql query is invalid") - def execute_with_retry_bkp( - self, - _run_manager: CallbackManagerForChainRun, - intermediate_steps: List, - question: str, - gql_query: str, - ) -> tuple[str, List[Any]]: - try: - intermediate_steps.append({"generated_query": gql_query}) - return gql_query, self.execute_query(_run_manager, gql_query) - except Exception as e: - retries = 0 - err_msg = str(e) - self.log_invalid_query(_run_manager, gql_query, err_msg) - intermediate_steps.pop() - intermediate_steps.append({"query_failed_" + str(retries + 1): gql_query}) - - new_gql_query = "" - while retries < self.max_gql_fix_retries: - try: - fix_chain_result = self.gql_fix_chain.invoke( - { - "question": question, - "err_msg": err_msg, - "generated_gql": gql_query, - "schema": self.graph.get_schema, - } - ) - new_gql_query = 
extract_gql(fix_chain_result) - intermediate_steps.append({"generated_query": new_gql_query}) - return new_gql_query, self.execute_query( - _run_manager, new_gql_query - ) - except Exception as e: - retries += 1 - gql_query = new_gql_query - err_msg = str(e) - self.log_invalid_query(_run_manager, gql_query, err_msg) - - raise ValueError("The generated gql query is invalid") - def log_invalid_query( self, _run_manager: CallbackManagerForChainRun, From c8799eaf9c0cc858f5825dd7b3041f52f0e2c9d6 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Wed, 4 Dec 2024 06:42:09 +0000 Subject: [PATCH 14/40] fix type check errors --- src/langchain_google_spanner/graph_qa.py | 26 ++++++------------------ 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/src/langchain_google_spanner/graph_qa.py b/src/langchain_google_spanner/graph_qa.py index bbf6861..e079f94 100644 --- a/src/langchain_google_spanner/graph_qa.py +++ b/src/langchain_google_spanner/graph_qa.py @@ -104,20 +104,6 @@ def extract_gql(text: str) -> str: return fix_gql_syntax(query) -def extract_verified_gql(json_response: str) -> str: - """Extract GQL query from a LLM response. - - Args: - response: Response to extract GQL query from. - - Returns: - GQL query extracted from the text. - """ - - json_response["verified_gql"] = fix_gql_syntax(str(json_response["verified_gql"])) - return json_response["verified_gql"] - - class SpannerGraphQAChain(Chain): """Chain for question-answering against a Spanner Graph database by generating GQL statements from natural language questions. 
@@ -200,12 +186,12 @@ def output_keys(self) -> List[str]: @classmethod def from_llm( cls, - llm: BaseLanguageModel = None, + llm: Optional[BaseLanguageModel] = None, *, - qa_prompt: BasePromptTemplate = None, - gql_prompt: BasePromptTemplate = None, - gql_verify_prompt: BasePromptTemplate = None, - gql_fix_prompt: BasePromptTemplate = None, + qa_prompt: Optional[BasePromptTemplate] = None, + gql_prompt: Optional[BasePromptTemplate] = None, + gql_verify_prompt: Optional[BasePromptTemplate] = None, + gql_fix_prompt: Optional[BasePromptTemplate] = None, qa_llm_kwargs: Optional[Dict[str, Any]] = None, gql_llm_kwargs: Optional[Dict[str, Any]] = None, gql_verify_llm_kwargs: Optional[Dict[str, Any]] = None, @@ -355,7 +341,7 @@ def _call( "graph_schema": self.graph.get_schema, } ) - verified_gql = extract_verified_gql(verify_response) + verified_gql = fix_gql_syntax(verify_response["verified_gql"]) intermediate_steps.append({"verified_gql": verified_gql}) else: verified_gql = generated_gql From 0d358aae597a27df4bb7c82b1ef359aadac3eab9 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Wed, 4 Dec 2024 23:31:31 +0000 Subject: [PATCH 15/40] Addressed review comments --- README.rst | 7 +++---- docs/graph_qa_chain.ipynb | 4 ++-- pyproject.toml | 6 ++++-- requirements.txt | 1 - src/langchain_google_spanner/graph_qa.py | 3 ++- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/README.rst b/README.rst index 5111312..cb047dd 100644 --- a/README.rst +++ b/README.rst @@ -151,15 +151,14 @@ See the full `Spanner Graph Store`_ tutorial. .. _`Spanner Graph Store`: https://github.com/googleapis/langchain-google-spanner-python/blob/main/docs/graph_store.ipynb -Spanner Graph QA Usage -~~~~~~~~~~~~~~~~~~~~~~~~~~ +Spanner Graph QA Chain Usage +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Use ``SpannerGraphQAChain`` for question answering over a graph stored in Spanner Graph. .. 
code:: python - from langchain_google_spanner import SpannerGraphQAChain - from langchain_google_spanner import SpannerGraphStore + from langchain_google_spanner import SpannerGraphStore, SpannerGraphQAChain from langchain_google_vertexai import ChatVertexAI diff --git a/docs/graph_qa_chain.ipynb b/docs/graph_qa_chain.ipynb index 4bc5ab3..7746e5c 100644 --- a/docs/graph_qa_chain.ipynb +++ b/docs/graph_qa_chain.ipynb @@ -36,7 +36,7 @@ "\n", "Learn more about the package on [GitHub](https://github.com/googleapis/langchain-google-spanner-python/).\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/langchain-google-spanner-python/blob/main/docs/graph_store.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/langchain-google-spanner-python/blob/main/docs/graph_qa_chain.ipynb)" ], "metadata": { "id": "7VBkjcqNNxEd" @@ -690,4 +690,4 @@ "outputs": [] } ] -} \ No newline at end of file +} diff --git a/pyproject.toml b/pyproject.toml index 86f2884..b5c5131 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,8 @@ authors = [ dependencies = [ "langchain-core>=0.1.25, <1.0.0", "langchain-community>=0.0.18, <1.0.0", - "google-cloud-spanner>=3.41.0, <4.0.0" + "google-cloud-spanner>=3.41.0, <4.0.0", + "pydantic>=2.9.1, <3.0.0" ] classifiers = [ "Intended Audience :: Developers", @@ -41,7 +42,8 @@ test = [ "mypy==1.11.2", "pytest==8.3.3", "pytest-asyncio==0.24.0", - "pytest-cov==5.0.0" + "pytest-cov==5.0.0", + "langchain_google_vertexai==1.0.10" ] [build-system] diff --git a/requirements.txt b/requirements.txt index 3a65981..9e16179 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ google-cloud-spanner==3.49.1 langchain-core==0.3.9 langchain-community==0.3.1 -langchain_google_vertexai pydantic==2.9.1 diff --git a/src/langchain_google_spanner/graph_qa.py 
b/src/langchain_google_spanner/graph_qa.py index e079f94..1e9ead2 100644 --- a/src/langchain_google_spanner/graph_qa.py +++ b/src/langchain_google_spanner/graph_qa.py @@ -26,7 +26,8 @@ from langchain_core.runnables import RunnableSequence from pydantic.v1 import BaseModel, Field -from .graph_store import SpannerGraphStore +from langchain_google_spanner.graph_store import SpannerGraphStore + from .prompts import ( DEFAULT_GQL_FIX_TEMPLATE, DEFAULT_GQL_TEMPLATE, From 340eadc0441331500ec458332a08ecfc40d51117 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Thu, 5 Dec 2024 00:57:27 +0000 Subject: [PATCH 16/40] Addressed review comments --- src/langchain_google_spanner/graph_qa.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/langchain_google_spanner/graph_qa.py b/src/langchain_google_spanner/graph_qa.py index 1e9ead2..ff399b4 100644 --- a/src/langchain_google_spanner/graph_qa.py +++ b/src/langchain_google_spanner/graph_qa.py @@ -73,6 +73,16 @@ class VerifyGqlOutput(BaseModel): def fix_gql_syntax(query: str) -> str: """Fixes the syntax of a GQL query. + Example 1: + Input: + MATCH (p:paper {id: 0})-[c:cites*8]->(p2:paper) + Output: + MATCH (p:paper {id: 0})-[c:cites]->{8}(p2:paper) + Example 2: + Input: + MATCH (p:paper {id: 0})-[c:cites*1..8]->(p2:paper) + Output: + MATCH (p:paper {id: 0})-[c:cites]->{1:8}(p2:paper) Args: query: The input GQL query. 
@@ -352,6 +362,8 @@ def _call( (final_gql, context) = self.execute_with_retry( _run_manager, intermediate_steps, question, verified_gql ) + if not final_gql: + raise ValueError("No GQL was generated.") _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) _run_manager.on_text( str(context), color="green", end="\n", verbose=self.verbose From 1439f044bc8915d13647de1ee2ba67e137ad0c36 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Thu, 5 Dec 2024 19:03:11 +0000 Subject: [PATCH 17/40] Clear default project id from notebook --- docs/graph_qa_chain.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/graph_qa_chain.ipynb b/docs/graph_qa_chain.ipynb index 7746e5c..26bda8a 100644 --- a/docs/graph_qa_chain.ipynb +++ b/docs/graph_qa_chain.ipynb @@ -150,7 +150,7 @@ "source": [ "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n", "\n", - "PROJECT_ID = \"google.com:cloud-spanner-demo\" # @param {type:\"string\"}\n", + "PROJECT_ID = \"\" # @param {type:\"string\"}\n", "\n", "# Set the project id\n", "!gcloud config set project {PROJECT_ID}\n", From 06d048912f87fff0fe781a0714fe35fd3be75a19 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Thu, 5 Dec 2024 19:13:08 +0000 Subject: [PATCH 18/40] Add import statement for SpanerGraphQAChain to notebook --- docs/graph_qa_chain.ipynb | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/graph_qa_chain.ipynb b/docs/graph_qa_chain.ipynb index 26bda8a..4d47f62 100644 --- a/docs/graph_qa_chain.ipynb +++ b/docs/graph_qa_chain.ipynb @@ -494,6 +494,7 @@ "cell_type": "code", "source": [ "from google.cloud import spanner\n", + "from langchain_google_spanner import SpannerGraphQAChain\n", "from langchain_google_vertexai import ChatVertexAI\n", "from IPython.core.display import HTML\n", "\n", From a086ff650f70da05d28930310aff6e327e560a17 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Thu, 12 Dec 2024 20:33:22 +0000 Subject: [PATCH 19/40] 
Add retrievers for Spanner Graph RAG --- src/langchain_google_spanner/__init__.py | 8 + src/langchain_google_spanner/graph_qa.py | 46 +-- .../graph_retriever.py | 276 ++++++++++++++++++ src/langchain_google_spanner/graph_utils.py | 62 ++++ .../test_spanner_graph_retriever.py | 223 ++++++++++++++ 5 files changed, 570 insertions(+), 45 deletions(-) create mode 100644 src/langchain_google_spanner/graph_retriever.py create mode 100644 src/langchain_google_spanner/graph_utils.py create mode 100644 tests/integration/test_spanner_graph_retriever.py diff --git a/src/langchain_google_spanner/__init__.py b/src/langchain_google_spanner/__init__.py index 28c2dd1..430a7fd 100644 --- a/src/langchain_google_spanner/__init__.py +++ b/src/langchain_google_spanner/__init__.py @@ -15,6 +15,11 @@ from langchain_google_spanner.chat_message_history import SpannerChatMessageHistory from langchain_google_spanner.graph_qa import SpannerGraphQAChain from langchain_google_spanner.graph_store import SpannerGraphStore +from langchain_google_spanner.graph_retriever import ( + SpannerGraphGQLRetriever, + SpannerGraphNodeVectorRetriever, + SpannerGraphSemanticGQLRetriever, +) from langchain_google_spanner.vector_store import ( DistanceStrategy, QueryParameters, @@ -38,4 +43,7 @@ "SecondaryIndex", "QueryParameters", "DistanceStrategy", + "SpannerGraphGQLRetriever", + "SpannerGraphNodeVectorRetriever", + "SpannerGraphSemanticGQLRetriever", ] diff --git a/src/langchain_google_spanner/graph_qa.py b/src/langchain_google_spanner/graph_qa.py index ff399b4..f071773 100644 --- a/src/langchain_google_spanner/graph_qa.py +++ b/src/langchain_google_spanner/graph_qa.py @@ -14,7 +14,6 @@ from __future__ import annotations -import re from typing import Any, Dict, List, Optional from langchain.chains.base import Chain @@ -28,6 +27,7 @@ from langchain_google_spanner.graph_store import SpannerGraphStore +from .graph_utils import extract_gql, fix_gql_syntax from .prompts import ( DEFAULT_GQL_FIX_TEMPLATE, 
DEFAULT_GQL_TEMPLATE, @@ -71,50 +71,6 @@ class VerifyGqlOutput(BaseModel): INTERMEDIATE_STEPS_KEY = "intermediate_steps" -def fix_gql_syntax(query: str) -> str: - """Fixes the syntax of a GQL query. - Example 1: - Input: - MATCH (p:paper {id: 0})-[c:cites*8]->(p2:paper) - Output: - MATCH (p:paper {id: 0})-[c:cites]->{8}(p2:paper) - Example 2: - Input: - MATCH (p:paper {id: 0})-[c:cites*1..8]->(p2:paper) - Output: - MATCH (p:paper {id: 0})-[c:cites]->{1:8}(p2:paper) - - Args: - query: The input GQL query. - - Returns: - Possibly modified GQL query. - """ - - query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]->", r"-[\1:\2]->{\3,\4}", query) - query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]->", r"-[\1:\2]->{\3}", query) - query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"<-[\1:\2]-{\3,\4}", query) - query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\]-", r"<-[\1:\2]-{\3}", query) - query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"-[\1:\2]-{\3,\4}", query) - query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]-", r"-[\1:\2]-{\3}", query) - return query - - -def extract_gql(text: str) -> str: - """Extract GQL query from a text. - - Args: - text: Text to extract GQL query from. - - Returns: - GQL query extracted from the text. - """ - pattern = r"```(.*?)```" - matches = re.findall(pattern, text, re.DOTALL) - query = matches[0] if matches else text - return fix_gql_syntax(query) - - class SpannerGraphQAChain(Chain): """Chain for question-answering against a Spanner Graph database by generating GQL statements from natural language questions. 
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Retrievers that answer natural-language questions over a Spanner Graph."""

import json
from typing import Any, List, Optional

from langchain.schema.retriever import BaseRetriever
from langchain_community.graphs.graph_document import GraphDocument
from langchain_core.callbacks import (
    CallbackManagerForChainRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.language_models import BaseLanguageModel
from langchain_core.load import dumps
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.vectorstores import InMemoryVectorStore
from pydantic import Field

from langchain_google_spanner.graph_store import SpannerGraphStore
from langchain_google_spanner.vector_store import DistanceStrategy, QueryParameters

from .graph_utils import extract_gql
from .prompts import DEFAULT_GQL_TEMPLATE, DEFAULT_GQL_TEMPLATE_PART1

# Prompt that asks the LLM to translate a question + graph schema into GQL.
GQL_GENERATION_PROMPT = PromptTemplate(
    template=DEFAULT_GQL_TEMPLATE,
    input_variables=["question", "schema"],
)


def graph_doc_to_doc(graph_doc: GraphDocument) -> Document:
    """Converts a GraphDocument to a Document.

    The graph document is serialized with LangChain's `dumps`, so the page
    content is a pretty-printed JSON representation of the result.
    """
    content = dumps(graph_doc, pretty=True)
    return Document(page_content=content, metadata={})


def get_distance_function(distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str:
    """Gets the Spanner vector distance function name for a strategy.

    NOTE: `EUCLIDEIAN` mirrors the spelling of the enum member declared in
    vector_store.py. Any strategy other than COSINE falls back to Euclidean.
    """
    if distance_strategy == DistanceStrategy.COSINE:
        return "COSINE_DISTANCE"
    return "EUCLIDEAN_DISTANCE"


def get_graph_name_from_schema(schema: str) -> str:
    """Extracts the graph name from the store's JSON schema string.

    The schema produced by SpannerGraphStore is a JSON object whose
    "Name of graph" key holds the graph identifier.
    """
    return json.loads(schema)["Name of graph"]


def duplicate_braces_in_string(text: str) -> str:
    """Replaces single curly braces with double curly braces in a string.

    This escapes the braces so the text can be embedded verbatim in a
    PromptTemplate without being interpreted as format placeholders.

    Args:
        text: The input string.

    Returns:
        The modified string with double curly braces.
    """
    return text.replace("{", "{{").replace("}", "}}")


class SpannerGraphGQLRetriever(BaseRetriever):
    """A Retriever that translates natural language queries to GQL and
    queries SpannerGraphStore using the GQL.
    Returns the documents retrieved as result.
    """

    graph_store: SpannerGraphStore = Field(exclude=True)
    # Optional so pydantic v2 accepts the None default; validated at query time.
    llm: Optional[BaseLanguageModel] = None
    k: int = 10
    """Number of top results to return"""

    def _get_relevant_documents(
        self, question: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Translate the natural language query to GQL, execute it,
        and return the results as Documents.

        Raises:
            ValueError: if `llm` is unset, or if the generated query fails.
        """
        if self.llm is None:
            raise ValueError("`llm` cannot be None")

        # 1. Generate a GQL query from the natural-language question.
        gql_chain = GQL_GENERATION_PROMPT | self.llm | StrOutputParser()
        gql_query = extract_gql(
            gql_chain.invoke(
                {
                    "question": question,
                    "schema": self.graph_store.get_schema,
                }
            )
        )

        # 2. Execute the GQL query against the Spanner graph.
        try:
            graph_documents = self.graph_store.query(gql_query)[: self.k]
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise ValueError(str(e)) from e

        # 3. Transform the results into a list of Documents.
        return [graph_doc_to_doc(gd) for gd in graph_documents]


class SpannerGraphSemanticGQLRetriever(BaseRetriever):
    """A Retriever that translates natural language queries to GQL and
    queries SpannerGraphStore to retrieve documents. It uses a semantic
    similarity model to compare the input question to a set of examples to
    generate the GQL query.
    """

    graph_store: SpannerGraphStore = Field(exclude=True)
    llm: Optional[BaseLanguageModel] = None
    k: int = 10
    """Number of top results to return"""
    selector: Optional[SemanticSimilarityExampleSelector] = None

    @classmethod
    def from_llm(
        cls, embedding_service: Optional[Embeddings] = None, **kwargs: Any
    ) -> "SpannerGraphSemanticGQLRetriever":
        """Builds the retriever with an empty semantic example selector.

        Raises:
            ValueError: if `embedding_service` is not provided.
        """
        if embedding_service is None:
            raise ValueError("`embedding_service` cannot be None")
        # from_examples expects the vector store *class* (it instantiates it
        # internally via from_texts); passing an instance would break it.
        selector = SemanticSimilarityExampleSelector.from_examples(
            [], embedding_service, InMemoryVectorStore, k=2
        )
        return cls(
            selector=selector,
            **kwargs,
        )

    def add_example(self, question: str, gql: str) -> None:
        """Registers a (question, GQL) pair as a few-shot example."""
        if self.selector is None:
            raise ValueError("`selector` cannot be None")
        # Braces are doubled so the GQL survives PromptTemplate formatting.
        self.selector.add_example(
            {"input": question, "query": duplicate_braces_in_string(gql)}
        )

    def _get_relevant_documents(
        self, question: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Translate the natural language query to GQL using examples looked up
        by a semantic similarity model, execute it, and return the results as
        Documents.

        Raises:
            ValueError: if `llm`/`selector` is unset, or the query fails.
        """
        if self.llm is None:
            raise ValueError("`llm` cannot be None")
        if self.selector is None:
            raise ValueError("`selector` cannot be None")

        # Few-shot prompt seeded with the semantically closest examples.
        prompt = FewShotPromptTemplate(
            example_selector=self.selector,
            example_prompt=PromptTemplate.from_template(
                "Question: {input}\nGQL Query: {query}"
            ),
            prefix="""
    Create an ISO GQL query for the question using the schema.""",
            suffix=DEFAULT_GQL_TEMPLATE_PART1,
            input_variables=["question", "schema"],
        )

        # 1. Generate a GQL query from the natural-language question.
        gql_chain = prompt | self.llm | StrOutputParser()
        gql_query = extract_gql(
            gql_chain.invoke(
                {
                    "question": question,
                    "schema": self.graph_store.get_schema,
                }
            )
        )

        # 2. Execute the GQL query against the Spanner graph.
        try:
            graph_documents = self.graph_store.query(gql_query)[: self.k]
        except Exception as e:
            raise ValueError(str(e)) from e

        # 3. Transform the results into a list of Documents.
        return [graph_doc_to_doc(gd) for gd in graph_documents]


class SpannerGraphNodeVectorRetriever(BaseRetriever):
    """Retriever that does a vector search on nodes in a SpannerGraphStore.
    If a graph expansion query is provided, it will be executed after the
    initial vector search to expand the returned context.
    """

    graph_store: SpannerGraphStore = Field(exclude=True)
    embedding_service: Embeddings
    label_expr: str = "%"
    return_properties_list: List[str] = []
    embeddings_column: str = "embedding"
    query_parameters: QueryParameters = QueryParameters()
    k: int = 10
    """Number of top results to return"""
    graph_expansion_query: Optional[str] = None
    """GQL query to expand the returned context"""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Exactly one retrieval shape must be configured up front.
        if not self.return_properties_list and self.graph_expansion_query is None:
            raise ValueError(
                "Either `return_properties` or `graph_expansion_query` must be provided."
            )

    def _get_relevant_documents(
        self, question: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Embeds the question, runs a nearest-neighbor match over graph
        nodes, and returns the results as Documents."""

        schema = self.graph_store.get_schema
        graph_name = get_graph_name_from_schema(schema)
        node_variable = "node"
        query_embeddings = self.embedding_service.embed_query(question)

        distance_fn = get_distance_function(self.query_parameters.distance_strategy)

        # 1. Build the KNN query over nodes matching `label_expr`.
        VECTOR_QUERY = """
        GRAPH {graph_name}
        MATCH ({node_var}:{label_expr})
        ORDER BY {distance_fn}({node_var}.{embeddings_column},
        ARRAY[{query_embeddings}])
        LIMIT {k}
        """
        gql_query = VECTOR_QUERY.format(
            graph_name=graph_name,
            node_var=node_variable,
            label_expr=self.label_expr,
            embeddings_column=self.embeddings_column,
            distance_fn=distance_fn,
            query_embeddings=",".join(map(str, query_embeddings)),
            k=self.k,
        )

        if self.return_properties_list:
            # Project only the requested node properties.
            return_properties = ",".join(
                node_variable + "." + prop for prop in self.return_properties_list
            )
            gql_query += """
            RETURN {}
            """.format(
                return_properties
            )
        elif self.graph_expansion_query is not None:
            # Chain a follow-up traversal that widens the returned context.
            gql_query += """
            RETURN node
            NEXT
            {}
            """.format(
                self.graph_expansion_query
            )
        else:
            # Unreachable in practice: __init__ enforces one of the two.
            raise ValueError(
                "Either `return_properties` or `graph_expansion_query` must be provided."
            )

        # 2. Execute the GQL query against the Spanner graph.
        try:
            graph_documents = self.graph_store.query(gql_query)[: self.k]
        except Exception as e:
            raise ValueError(str(e)) from e

        # 3. Transform the results into a list of Documents.
        return [graph_doc_to_doc(gd) for gd in graph_documents]
Transform the results into a list of Documents + documents = [] + for graph_document in graph_documents: + documents.append(graph_doc_to_doc(graph_document)) + return documents diff --git a/src/langchain_google_spanner/graph_utils.py b/src/langchain_google_spanner/graph_utils.py new file mode 100644 index 0000000..48411fc --- /dev/null +++ b/src/langchain_google_spanner/graph_utils.py @@ -0,0 +1,62 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import re + + +def fix_gql_syntax(query: str) -> str: + """Fixes the syntax of a GQL query. + Example 1: + Input: + MATCH (p:paper {id: 0})-[c:cites*8]->(p2:paper) + Output: + MATCH (p:paper {id: 0})-[c:cites]->{8}(p2:paper) + Example 2: + Input: + MATCH (p:paper {id: 0})-[c:cites*1..8]->(p2:paper) + Output: + MATCH (p:paper {id: 0})-[c:cites]->{1:8}(p2:paper) + + Args: + query: The input GQL query. + + Returns: + Possibly modified GQL query. 
+ """ + + query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]->", r"-[\1:\2]->{\3,\4}", query) + query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]->", r"-[\1:\2]->{\3}", query) + query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"<-[\1:\2]-{\3,\4}", query) + query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\]-", r"<-[\1:\2]-{\3}", query) + query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"-[\1:\2]-{\3,\4}", query) + query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]-", r"-[\1:\2]-{\3}", query) + return query + + +def extract_gql(text: str) -> str: + """Extract GQL query from a text. + + Args: + text: Text to extract GQL query from. + + Returns: + GQL query extracted from the text. + """ + pattern = r"```(.*?)```" + matches = re.findall(pattern, text, re.DOTALL) + query = matches[0] if matches else text + return fix_gql_syntax(query) diff --git a/tests/integration/test_spanner_graph_retriever.py b/tests/integration/test_spanner_graph_retriever.py new file mode 100644 index 0000000..bf72275 --- /dev/null +++ b/tests/integration/test_spanner_graph_retriever.py @@ -0,0 +1,223 @@ +import os +import random +import string + +import pytest +from google.cloud import spanner +from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship +from langchain_core.documents import Document +from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings + +from langchain_google_spanner.graph_retriever import ( + SpannerGraphGQLRetriever, + SpannerGraphNodeVectorRetriever, + SpannerGraphSemanticGQLRetriever, +) +from langchain_google_spanner.graph_store import SpannerGraphStore + +project_id = os.environ["PROJECT_ID"] +instance_id = os.environ["INSTANCE_ID"] +database_id = os.environ["DATABASE_ID"] + + +def random_string(num_char=3): + return "".join(random.choice(string.ascii_letters) for _ in range(num_char)) + + +def get_llm(): + llm = ChatVertexAI( + model="gemini-1.5-flash-002", + temperature=0, + ) + return llm + + +def get_embedding(): + embeddings = 
def get_embedding():
    # text-embedding-004: Vertex AI model used both to build the stored node
    # embeddings and to embed questions at query time, so distances compare
    # like with like.
    embeddings = VertexAIEmbeddings(model_name="text-embedding-004")
    return embeddings


def get_spanner_graph():
    """Creates a uniquely named SpannerGraphStore for one test run."""
    # Random suffix keeps concurrent test runs from colliding on graph names.
    suffix = random_string(num_char=3)
    graph_name = "test_graph{}".format(suffix)
    graph = SpannerGraphStore(
        instance_id=instance_id,
        database_id=database_id,
        graph_name=graph_name,
        client=spanner.Client(project=project_id),
    )
    return graph, suffix


def load_data(graph: SpannerGraphStore, suffix: str):
    """Loads a small fixture graph (people, pet, locations) into `graph`."""
    # Node/edge labels are suffixed per run so parallel runs do not share
    # schema objects in the same database.
    type_suffix = "_" + suffix
    graph_documents = [
        GraphDocument(
            nodes=[
                Node(
                    id="Elias Thorne",
                    type="Person" + type_suffix,
                    properties={
                        "name": "Elias Thorne",
                        "description": "lived in the desert",
                    },
                ),
                Node(
                    id="Zephyr",
                    type="Animal" + type_suffix,
                    properties={"name": "Zephyr", "description": "pet falcon"},
                ),
                Node(
                    id="Elara",
                    type="Person" + type_suffix,
                    properties={
                        "name": "Elara",
                        "description": "resided in the capital city",
                    },
                ),
                Node(id="Desert", type="Location" + type_suffix, properties={}),
                Node(id="Capital City", type="Location" + type_suffix, properties={}),
            ],
            relationships=[
                Relationship(
                    source=Node(
                        id="Elias Thorne", type="Person" + type_suffix, properties={}
                    ),
                    target=Node(
                        id="Desert", type="Location" + type_suffix, properties={}
                    ),
                    type="LivesIn",
                    properties={},
                ),
                Relationship(
                    source=Node(
                        id="Elias Thorne", type="Person" + type_suffix, properties={}
                    ),
                    target=Node(
                        id="Zephyr", type="Animal" + type_suffix, properties={}
                    ),
                    type="Owns",
                    properties={},
                ),
                Relationship(
                    source=Node(id="Elara", type="Person" + type_suffix, properties={}),
                    target=Node(
                        id="Capital City", type="Location" + type_suffix, properties={}
                    ),
                    type="LivesIn",
                    properties={},
                ),
                Relationship(
                    source=Node(
                        id="Elias Thorne", type="Person" + type_suffix, properties={}
                    ),
                    target=Node(id="Elara", type="Person" + type_suffix, properties={}),
                    type="Sibling",
                    properties={},
                ),
            ],
            source=Document(
                metadata={},
                page_content=(
                    "Elias Thorne lived in the desert. He was a skilled craftsman"
                    " who worked with sandstone. Elias had a pet falcon named"
                    " Zephyr. His sister, Elara, resided in the capital city and"
                    " ran a spice shop. They rarely met due to the distance."
                ),
            ),
        )
    ]

    # Add embeddings to the graph documents for Person nodes
    embedding_service = get_embedding()
    for graph_document in graph_documents:
        for node in graph_document.nodes:
            if node.type == "Person{}".format(type_suffix):
                if "description" in node.properties:
                    node.properties["desc_embedding"] = embedding_service.embed_query(
                        node.properties["description"]
                    )
    graph.add_graph_documents(graph_documents)
    graph.refresh_schema()


class TestRetriever:

    @pytest.fixture
    def setup_db_load_data(self):
        # Creates a throwaway graph, loads the fixture data, and tears the
        # graph down after the test body finishes.
        graph, suffix = get_spanner_graph()
        load_data(graph, suffix)
        yield graph, suffix
        # teardown
        graph.cleanup()

    def test_spanner_graph_gql_retriever(self, setup_db_load_data):
        # End-to-end: LLM-generated GQL should follow the Sibling edge and
        # find Elara's location.
        graph, suffix = setup_db_load_data
        retriever = SpannerGraphGQLRetriever(
            graph_store=graph,
            llm=get_llm(),
        )
        response = retriever.invoke("Where does Elias Thorne's sibling live?")

        assert len(response) == 1
        assert "Capital City" in response[0].page_content

    def test_spanner_graph_semantic_gql_retriever(self, setup_db_load_data):
        graph, suffix = setup_db_load_data
        suffix = "_" + suffix
        retriever = SpannerGraphSemanticGQLRetriever.from_llm(
            graph_store=graph,
            llm=get_llm(),
            embedding_service=get_embedding(),
        )
        # Few-shot examples: doubled braces ({{...}}) survive the retriever's
        # internal PromptTemplate formatting.
        retriever.add_example(
            "Where does Sam Smith live?",
            """
        GRAPH QAGraph
        MATCH (n:Person{suffix} {{name: "Sam Smith"}})-[:LivesIn]->(l:Location{suffix})
        RETURN l.id AS location_id
        """.format(
                suffix=suffix
            ),
        )
        retriever.add_example(
            "Where does Sam Smith's sibling live?",
            """
        GRAPH QAGraph
        MATCH (n:Person{suffix} {{name: "Sam Smith"}})-[:Sibling]->(m:Person{suffix})-[:LivesIn]->(l:Location{suffix})
        RETURN l.id AS location_id
        """.format(
                suffix=suffix
            ),
        )
        response = retriever.invoke("Where does Elias Thorne's sibling live?")
        assert response == [
            Document(metadata={}, page_content='{\n  "location_id": "Capital City"\n}')
        ]

    def test_spanner_graph_vector_node_retriever_error(self, setup_db_load_data):
        # Neither return_properties_list nor graph_expansion_query is given,
        # so construction must raise.
        with pytest.raises(ValueError):
            graph, suffix = setup_db_load_data
            suffix = "_" + suffix
            SpannerGraphNodeVectorRetriever(
                graph_store=graph,
                embedding_service=get_embedding(),
                label_expr="Person{}".format(suffix),
                embeddings_column="desc_embedding",
                k=1,
            )

    def test_spanner_graph_vector_node_retriever(self, setup_db_load_data):
        # KNN over desc_embedding should rank Elias ("lived in the desert")
        # first for a desert question.
        graph, suffix = setup_db_load_data
        suffix = "_" + suffix
        retriever = SpannerGraphNodeVectorRetriever(
            graph_store=graph,
            embedding_service=get_embedding(),
            label_expr="Person{}".format(suffix),
            return_properties_list=["name"],
            embeddings_column="desc_embedding",
            k=1,
        )
        response = retriever.invoke("Who lives in desert?")
        assert len(response) == 1
        assert "Elias Thorne" in response[0].page_content
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json from typing import Any, List diff --git a/tests/integration/test_spanner_graph_retriever.py b/tests/integration/test_spanner_graph_retriever.py index bf72275..9a112a9 100644 --- a/tests/integration/test_spanner_graph_retriever.py +++ b/tests/integration/test_spanner_graph_retriever.py @@ -1,3 +1,17 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import os import random import string From 704ed322dffdbdd608a10a7493351306616d238d Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Thu, 12 Dec 2024 20:46:35 +0000 Subject: [PATCH 21/40] Fix DATABASE name key --- tests/integration/test_spanner_graph_retriever.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_spanner_graph_retriever.py b/tests/integration/test_spanner_graph_retriever.py index 9a112a9..5038c13 100644 --- a/tests/integration/test_spanner_graph_retriever.py +++ b/tests/integration/test_spanner_graph_retriever.py @@ -31,7 +31,7 @@ project_id = os.environ["PROJECT_ID"] instance_id = os.environ["INSTANCE_ID"] -database_id = os.environ["DATABASE_ID"] +database_id = os.environ["GOOGLE_DATABASE"] def random_string(num_char=3): From 11f674ef86b2f05265daa9422869e0fd4a590df1 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Thu, 12 Dec 2024 21:46:54 +0000 Subject: [PATCH 22/40] Fix lint error on import ordering --- src/langchain_google_spanner/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/langchain_google_spanner/__init__.py b/src/langchain_google_spanner/__init__.py index 430a7fd..84383ad 100644 --- a/src/langchain_google_spanner/__init__.py +++ b/src/langchain_google_spanner/__init__.py @@ -14,12 +14,12 @@ from langchain_google_spanner.chat_message_history import SpannerChatMessageHistory from langchain_google_spanner.graph_qa import SpannerGraphQAChain -from langchain_google_spanner.graph_store import SpannerGraphStore from langchain_google_spanner.graph_retriever import ( SpannerGraphGQLRetriever, SpannerGraphNodeVectorRetriever, SpannerGraphSemanticGQLRetriever, ) +from langchain_google_spanner.graph_store import SpannerGraphStore from langchain_google_spanner.vector_store import ( DistanceStrategy, QueryParameters, From 1daf9e668eef46f0259b544bed7fbbd48fd94c32 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Fri, 13 Dec 2024 01:38:53 +0000 Subject: [PATCH 23/40] 
Fix lint errors --- .../graph_retriever.py | 87 ++++++++++++++----- .../test_spanner_graph_retriever.py | 8 +- 2 files changed, 67 insertions(+), 28 deletions(-) diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index 57a7ba8..120805a 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -13,7 +13,7 @@ # limitations under the License. import json -from typing import Any, List +from typing import Any, List, Optional from langchain.schema.retriever import BaseRetriever from langchain_community.graphs.graph_document import GraphDocument @@ -29,6 +29,7 @@ from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import FewShotPromptTemplate from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.runnables import RunnableSequence from langchain_core.vectorstores import InMemoryVectorStore from pydantic import Field @@ -44,9 +45,9 @@ ) -def graph_doc_to_doc(graph_doc: GraphDocument) -> Document: - """Converts a GraphDocument to a Document.""" - content = dumps(graph_doc, pretty=True) +def convert_to_doc(data: dict[str, Any]) -> Document: + """Converts data to a Document.""" + content = dumps(data, pretty=True) return Document(page_content=content, metadata={}) @@ -83,10 +84,19 @@ class SpannerGraphGQLRetriever(BaseRetriever): """ graph_store: SpannerGraphStore = Field(exclude=True) - llm: BaseLanguageModel = None + gql_chain: RunnableSequence k: int = 10 """Number of top results to return""" + @classmethod + def from_params( + cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any + ) -> "SpannerGraphGQLRetriever": + if llm is None: + raise ValueError("`llm` cannot be none") + gql_chain = GQL_GENERATION_PROMPT | llm | StrOutputParser() + return cls(gql_chain=gql_chain, **kwargs) + def _get_relevant_documents( self, question: str, *, run_manager: CallbackManagerForRetrieverRun ) -> 
List[Document]: @@ -94,12 +104,9 @@ def _get_relevant_documents( and return the results as Documents. """ - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - # Initialize the chain - gql_chain = GQL_GENERATION_PROMPT | self.llm | StrOutputParser() # 1. Generate gql query from natural language query using LLM gql_query = extract_gql( - gql_chain.invoke( + self.gql_chain.invoke( { "question": question, "schema": self.graph_store.get_schema, @@ -110,14 +117,14 @@ def _get_relevant_documents( # 2. Execute the gql query against spanner graph try: - graph_documents = self.graph_store.query(gql_query)[: self.k] + responses = self.graph_store.query(gql_query)[: self.k] except Exception as e: raise ValueError(str(e)) # 3. Transform the results into a list of Documents documents = [] - for graph_document in graph_documents: - documents.append(graph_doc_to_doc(graph_document)) + for response in responses: + documents.append(convert_to_doc(response)) return documents @@ -129,24 +136,34 @@ class SpannerGraphSemanticGQLRetriever(BaseRetriever): """ graph_store: SpannerGraphStore = Field(exclude=True) - llm: BaseLanguageModel = None k: int = 10 """Number of top results to return""" - selector: SemanticSimilarityExampleSelector = None + llm: Optional[BaseLanguageModel] = None + selector: Optional[SemanticSimilarityExampleSelector] = None @classmethod - def from_llm( - cls, embedding_service: Embeddings = None, **kwargs: Any + def from_params( + cls, + llm: Optional[BaseLanguageModel] = None, + embedding_service: Optional[Embeddings] = None, + **kwargs: Any, ) -> "SpannerGraphSemanticGQLRetriever": + if llm is None: + raise ValueError("`llm` cannot be none") + if embedding_service is None: + raise ValueError("`embedding_service` cannot be none") selector = SemanticSimilarityExampleSelector.from_examples( - [], embedding_service, InMemoryVectorStore(embedding_service), k=2 + [], embedding_service, InMemoryVectorStore, k=2 ) return cls( + llm=llm, 
selector=selector, **kwargs, ) def add_example(self, question: str, gql: str): + if self.selector is None: + raise ValueError("`selector` cannot be None") self.selector.add_example( {"input": question, "query": duplicate_braces_in_string(gql)} ) @@ -159,6 +176,11 @@ def _get_relevant_documents( Documents. """ + if self.llm is None: + raise ValueError("`llm` cannot be None") + if self.selector is None: + raise ValueError("`selector` cannot be None") + # Define the prompt template prompt = FewShotPromptTemplate( example_selector=self.selector, @@ -186,14 +208,14 @@ def _get_relevant_documents( # 2. Execute the gql query against spanner graph try: - graph_documents = self.graph_store.query(gql_query)[: self.k] + responses = self.graph_store.query(gql_query)[: self.k] except Exception as e: raise ValueError(str(e)) # 3. Transform the results into a list of Documents documents = [] - for graph_document in graph_documents: - documents.append(graph_doc_to_doc(graph_document)) + for response in responses: + documents.append(convert_to_doc(response)) return documents @@ -204,21 +226,35 @@ class SpannerGraphNodeVectorRetriever(BaseRetriever): """ graph_store: SpannerGraphStore = Field(exclude=True) - embedding_service: Embeddings + embedding_service: Optional[Embeddings] = None label_expr: str = "%" + """A label expression for the nodes to search""" return_properties_list: List[str] = [] + """The list of properties to return""" embeddings_column: str = "embedding" + """The name of the column that stores embedding""" query_parameters: QueryParameters = QueryParameters() k: int = 10 """Number of top results to return""" - graph_expansion_query: str = None + graph_expansion_query: str = "" """GQL query to expand the returned context""" + @classmethod + def from_params( + cls, embedding_service: Optional[Embeddings] = None, **kwargs: Any + ) -> "SpannerGraphNodeVectorRetriever": + if embedding_service is None: + raise ValueError("`embedding_service` cannot be None") + return 
cls( + embedding_service=embedding_service, + **kwargs, + ) + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) print(self.return_properties_list) print(self.graph_expansion_query) - if not self.return_properties_list and self.graph_expansion_query is None: + if not self.return_properties_list and not self.graph_expansion_query: raise ValueError( "Either `return_properties` or `graph_expansion_query` must be provided." ) @@ -229,6 +265,9 @@ def _get_relevant_documents( """Translate the natural language query to GQL, execute it, and return the results as Documents.""" + if self.embedding_service is None: + raise ValueError("`embedding_service` cannot be None") + schema = self.graph_store.get_schema graph_name = get_graph_name_from_schema(schema) node_variable = "node" @@ -286,5 +325,5 @@ def _get_relevant_documents( # 3. Transform the results into a list of Documents documents = [] for graph_document in graph_documents: - documents.append(graph_doc_to_doc(graph_document)) + documents.append(convert_to_doc(graph_document)) return documents diff --git a/tests/integration/test_spanner_graph_retriever.py b/tests/integration/test_spanner_graph_retriever.py index 5038c13..80e25ee 100644 --- a/tests/integration/test_spanner_graph_retriever.py +++ b/tests/integration/test_spanner_graph_retriever.py @@ -167,7 +167,7 @@ def setup_db_load_data(self): def test_spanner_graph_gql_retriever(self, setup_db_load_data): graph, suffix = setup_db_load_data - retriever = SpannerGraphGQLRetriever( + retriever = SpannerGraphGQLRetriever.from_params( graph_store=graph, llm=get_llm(), ) @@ -179,7 +179,7 @@ def test_spanner_graph_gql_retriever(self, setup_db_load_data): def test_spanner_graph_semantic_gql_retriever(self, setup_db_load_data): graph, suffix = setup_db_load_data suffix = "_" + suffix - retriever = SpannerGraphSemanticGQLRetriever.from_llm( + retriever = SpannerGraphSemanticGQLRetriever.from_params( graph_store=graph, llm=get_llm(), 
embedding_service=get_embedding(), @@ -213,7 +213,7 @@ def test_spanner_graph_vector_node_retriever_error(self, setup_db_load_data): with pytest.raises(ValueError): graph, suffix = setup_db_load_data suffix = "_" + suffix - SpannerGraphNodeVectorRetriever( + SpannerGraphNodeVectorRetriever.from_params( graph_store=graph, embedding_service=get_embedding(), label_expr="Person{}".format(suffix), @@ -224,7 +224,7 @@ def test_spanner_graph_vector_node_retriever_error(self, setup_db_load_data): def test_spanner_graph_vector_node_retriever(self, setup_db_load_data): graph, suffix = setup_db_load_data suffix = "_" + suffix - retriever = SpannerGraphNodeVectorRetriever( + retriever = SpannerGraphNodeVectorRetriever.from_params( graph_store=graph, embedding_service=get_embedding(), label_expr="Person{}".format(suffix), From c46ee65a014de65360507c95e6954103bd2afbb4 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Mon, 16 Dec 2024 23:17:51 +0000 Subject: [PATCH 24/40] Few minor changes to the SpannerGraphNodeVectorRetriever --- src/langchain_google_spanner/graph_retriever.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index 120805a..1e8a492 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -234,10 +234,12 @@ class SpannerGraphNodeVectorRetriever(BaseRetriever): embeddings_column: str = "embedding" """The name of the column that stores embedding""" query_parameters: QueryParameters = QueryParameters() - k: int = 10 - """Number of top results to return""" + top_k: int = 3 + """Number of vector similarity matches to return""" graph_expansion_query: str = "" """GQL query to expand the returned context""" + k: int = 10 + """Number of graph results to return""" @classmethod def from_params( @@ -278,9 +280,10 @@ def _get_relevant_documents( VECTOR_QUERY = """ GRAPH {graph_name} 
MATCH ({node_var}:{label_expr}) + WHERE {node_var}.{embeddings_column} IS NOT NULL ORDER BY {distance_fn}({node_var}.{embeddings_column}, ARRAY[{query_embeddings}]) - LIMIT {k} + LIMIT {top_k} """ gql_query = VECTOR_QUERY.format( graph_name=graph_name, @@ -289,7 +292,7 @@ def _get_relevant_documents( embeddings_column=self.embeddings_column, distance_fn=distance_fn, query_embeddings=",".join(map(str, query_embeddings)), - k=self.k, + top_k=self.top_k, ) if self.return_properties_list: From 67d9c5e98b03ddfe6987fd055014363f9a51dcfb Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Mon, 16 Dec 2024 23:48:35 +0000 Subject: [PATCH 25/40] Fix lint error --- src/langchain_google_spanner/graph_retriever.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index 1e8a492..b73eed2 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -94,7 +94,9 @@ def from_params( ) -> "SpannerGraphGQLRetriever": if llm is None: raise ValueError("`llm` cannot be none") - gql_chain = GQL_GENERATION_PROMPT | llm | StrOutputParser() + gql_chain: RunnableSequence = RunnableSequence( + GQL_GENERATION_PROMPT | llm | StrOutputParser() + ) return cls(gql_chain=gql_chain, **kwargs) def _get_relevant_documents( From b3e4e3c1d33a1fcafc164a471af2defaa7da21fa Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Wed, 18 Dec 2024 23:37:36 +0000 Subject: [PATCH 26/40] Add an option to expand context graph by hops --- .../graph_retriever.py | 78 +++++++++++++++---- .../test_spanner_graph_retriever.py | 23 +++++- 2 files changed, 86 insertions(+), 15 deletions(-) diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index b73eed2..33bf773 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -47,7 +47,7 @@ 
def convert_to_doc(data: dict[str, Any]) -> Document: """Converts data to a Document.""" - content = dumps(data, pretty=True) + content = dumps(data) return Document(page_content=content, metadata={}) @@ -77,6 +77,31 @@ def duplicate_braces_in_string(text): return text +def clean_element(element, embedding_column): + """Removes specified keys and embedding from properties in graph element. + + Args: + element: A dictionary representing element + + Returns: + A cleaned dictionary with the specified keys removed. + """ + + keys_to_remove = [ + "source_node_identifier", + "destination_node_identifier", + "identifier", + ] + for key in keys_to_remove: + if key in element: + del element[key] + + if "properties" in element and embedding_column in element["properties"]: + del element["properties"][embedding_column] + + return element + + class SpannerGraphGQLRetriever(BaseRetriever): """A Retriever that translates natural language queries to GQL and queries SpannerGraphStore using the GQL. @@ -206,7 +231,7 @@ def _get_relevant_documents( } ) ) - print(gql_query) # TODO(amullick): REMOVE + print(gql_query) # 2. 
Execute the gql query against spanner graph try: @@ -240,6 +265,8 @@ class SpannerGraphNodeVectorRetriever(BaseRetriever): """Number of vector similarity matches to return""" graph_expansion_query: str = "" """GQL query to expand the returned context""" + expand_by_hops: int = -1 + """Number of hops to traverse to expand graph results""" k: int = 10 """Number of graph results to return""" @@ -256,11 +283,19 @@ def from_params( def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - print(self.return_properties_list) - print(self.graph_expansion_query) - if not self.return_properties_list and not self.graph_expansion_query: + if self.embedding_service is None: + raise ValueError("`embedding_service` cannot be None") + + sum = 0 + if self.return_properties_list: + sum += 1 + if self.graph_expansion_query: + sum += 1 + if self.expand_by_hops != -1: + sum += 1 + if sum != 1: raise ValueError( - "Either `return_properties` or `graph_expansion_query` must be provided." + "One and only one of `return_properties` or `graph_expansion_query` or `expand_by_hops` must be provided." ) def _get_relevant_documents( @@ -269,9 +304,6 @@ def _get_relevant_documents( """Translate the natural language query to GQL, execute it, and return the results as Documents.""" - if self.embedding_service is None: - raise ValueError("`embedding_service` cannot be None") - schema = self.graph_store.get_schema graph_name = get_graph_name_from_schema(schema) node_variable = "node" @@ -297,7 +329,16 @@ def _get_relevant_documents( top_k=self.top_k, ) - if self.return_properties_list: + if self.expand_by_hops >= 0: + gql_query += """ + RETURN node + NEXT + MATCH p = (node) -[]-{{0,{}}} () + RETURN SAFE_TO_JSON(p) as path + """.format( + self.expand_by_hops + ) + elif self.return_properties_list: return_properties = ",".join( map(lambda x: node_variable + "." + x, self.return_properties_list) ) @@ -323,12 +364,23 @@ def _get_relevant_documents( # 2. 
Execute the gql query against spanner graph try: - graph_documents = self.graph_store.query(gql_query)[: self.k] + responses = self.graph_store.query(gql_query)[: self.k] except Exception as e: raise ValueError(str(e)) # 3. Transform the results into a list of Documents documents = [] - for graph_document in graph_documents: - documents.append(convert_to_doc(graph_document)) + if self.expand_by_hops >= 0: + for response in responses: + elements = json.loads((response["path"]).serialize()) + for element in elements: + clean_element(element, self.embeddings_column) + response["path"] = elements + content = dumps(response["path"]) + documents.append(Document(page_content=content, metadata={})) + + else: + for response in responses: + documents.append(convert_to_doc(response)) + return documents diff --git a/tests/integration/test_spanner_graph_retriever.py b/tests/integration/test_spanner_graph_retriever.py index 80e25ee..3ecd16e 100644 --- a/tests/integration/test_spanner_graph_retriever.py +++ b/tests/integration/test_spanner_graph_retriever.py @@ -157,7 +157,7 @@ def load_data(graph: SpannerGraphStore, suffix: str): class TestRetriever: - @pytest.fixture + @pytest.fixture(scope="module") def setup_db_load_data(self): graph, suffix = get_spanner_graph() load_data(graph, suffix) @@ -206,7 +206,7 @@ def test_spanner_graph_semantic_gql_retriever(self, setup_db_load_data): ) response = retriever.invoke("Where does Elias Thorne's sibling live?") assert response == [ - Document(metadata={}, page_content='{\n "location_id": "Capital City"\n}') + Document(metadata={}, page_content='{"location_id": "Capital City"}') ] def test_spanner_graph_vector_node_retriever_error(self, setup_db_load_data): @@ -230,8 +230,27 @@ def test_spanner_graph_vector_node_retriever(self, setup_db_load_data): label_expr="Person{}".format(suffix), return_properties_list=["name"], embeddings_column="desc_embedding", + top_k=1, k=1, ) response = retriever.invoke("Who lives in desert?") assert 
len(response) == 1 assert "Elias Thorne" in response[0].page_content + + def test_spanner_graph_vector_node_retriever_2(self, setup_db_load_data): + graph, suffix = setup_db_load_data + suffix = "_" + suffix + retriever = SpannerGraphNodeVectorRetriever.from_params( + graph_store=graph, + embedding_service=get_embedding(), + label_expr="Person{}".format(suffix), + expand_by_hops=1, + embeddings_column="desc_embedding", + top_k=1, + k=10, + ) + response = retriever.invoke( + "What do you know about the person who lives in desert?" + ) + assert len(response) == 4 + assert "Elias Thorne" in response[0].page_content From f8db780f8cad9c172442e80bb715fdeedabc5059 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Wed, 18 Dec 2024 23:45:24 +0000 Subject: [PATCH 27/40] Fix lint error --- src/langchain_google_spanner/graph_retriever.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index 33bf773..e267b05 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -307,6 +307,9 @@ def _get_relevant_documents( schema = self.graph_store.get_schema graph_name = get_graph_name_from_schema(schema) node_variable = "node" + + if self.embedding_service is None: + raise ValueError("`embedding_service` cannot be None") query_embeddings = self.embedding_service.embed_query(question) distance_fn = get_distance_function(self.query_parameters.distance_strategy) From c0fdd6900fea9ce43862aab0d770ce3210147159 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Fri, 20 Dec 2024 19:05:50 +0000 Subject: [PATCH 28/40] Addressed review comments --- src/langchain_google_spanner/graph_retriever.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index e267b05..41125bf 100644 --- 
a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -143,10 +143,7 @@ def _get_relevant_documents( print(gql_query) # 2. Execute the gql query against spanner graph - try: - responses = self.graph_store.query(gql_query)[: self.k] - except Exception as e: - raise ValueError(str(e)) + responses = self.graph_store.query(gql_query)[: self.k] # 3. Transform the results into a list of Documents documents = [] @@ -234,10 +231,7 @@ def _get_relevant_documents( print(gql_query) # 2. Execute the gql query against spanner graph - try: - responses = self.graph_store.query(gql_query)[: self.k] - except Exception as e: - raise ValueError(str(e)) + responses = self.graph_store.query(gql_query)[: self.k] # 3. Transform the results into a list of Documents documents = [] @@ -336,7 +330,7 @@ def _get_relevant_documents( gql_query += """ RETURN node NEXT - MATCH p = (node) -[]-{{0,{}}} () + MATCH p = TRAIL (node) -[]-{{0,{}}} () RETURN SAFE_TO_JSON(p) as path """.format( self.expand_by_hops @@ -366,10 +360,7 @@ def _get_relevant_documents( print(gql_query) # 2. Execute the gql query against spanner graph - try: - responses = self.graph_store.query(gql_query)[: self.k] - except Exception as e: - raise ValueError(str(e)) + responses = self.graph_store.query(gql_query)[: self.k] # 3. 
Transform the results into a list of Documents documents = [] From 17dba7f466046bf761bbf61caf7fab3e2dc2bdd6 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Fri, 20 Dec 2024 19:13:40 +0000 Subject: [PATCH 29/40] Remove expansion query options --- .../graph_retriever.py | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index 41125bf..b41f86c 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -242,8 +242,8 @@ def _get_relevant_documents( class SpannerGraphNodeVectorRetriever(BaseRetriever): """Retriever that does a vector search on nodes in a SpannerGraphStore. - If a graph expansion query is provided, it will be executed after the - initial vector search to expand the returned context. + If expand_by_hops is provided , the nodes (and edges) at a distance upto + the expand_by hops will also be returned. """ graph_store: SpannerGraphStore = Field(exclude=True) @@ -257,8 +257,6 @@ class SpannerGraphNodeVectorRetriever(BaseRetriever): query_parameters: QueryParameters = QueryParameters() top_k: int = 3 """Number of vector similarity matches to return""" - graph_expansion_query: str = "" - """GQL query to expand the returned context""" expand_by_hops: int = -1 """Number of hops to traverse to expand graph results""" k: int = 10 @@ -283,13 +281,11 @@ def __init__(self, **kwargs: Any) -> None: sum = 0 if self.return_properties_list: sum += 1 - if self.graph_expansion_query: - sum += 1 if self.expand_by_hops != -1: sum += 1 if sum != 1: raise ValueError( - "One and only one of `return_properties` or `graph_expansion_query` or `expand_by_hops` must be provided." + "One and only one of `return_properties` or `expand_by_hops` must be provided." 
) def _get_relevant_documents( @@ -344,17 +340,9 @@ def _get_relevant_documents( """.format( return_properties ) - elif self.graph_expansion_query is not None: - gql_query += """ - RETURN node - NEXT - {} - """.format( - self.graph_expansion_query - ) else: raise ValueError( - "Either `return_properties` or `graph_expansion_query` must be provided." + "Either `return_properties` or `expand_by_hops` must be provided." ) print(gql_query) From ebe0b86621ef05bd94a9e09bacc22cbd179b8cd0 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Sat, 4 Jan 2025 00:20:12 +0000 Subject: [PATCH 30/40] Add backticks to property names --- src/langchain_google_spanner/graph_retriever.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index b41f86c..054ff35 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -333,7 +333,10 @@ def _get_relevant_documents( ) elif self.return_properties_list: return_properties = ",".join( - map(lambda x: node_variable + "." 
+ x, self.return_properties_list) + map( + lambda x: node_variable + ".`" + x + "`", + self.return_properties_list, + ) ) gql_query += """ RETURN {} From eb30c876e602746b881b33a4dcccf1b620d738e0 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Mon, 6 Jan 2025 22:08:57 +0000 Subject: [PATCH 31/40] Change copyright year --- src/langchain_google_spanner/graph_retriever.py | 2 +- src/langchain_google_spanner/graph_utils.py | 2 +- tests/integration/test_spanner_graph_retriever.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index 054ff35..d748de6 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/langchain_google_spanner/graph_utils.py b/src/langchain_google_spanner/graph_utils.py index 48411fc..8f3fe97 100644 --- a/src/langchain_google_spanner/graph_utils.py +++ b/src/langchain_google_spanner/graph_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/integration/test_spanner_graph_retriever.py b/tests/integration/test_spanner_graph_retriever.py index 3ecd16e..edcd800 100644 --- a/tests/integration/test_spanner_graph_retriever.py +++ b/tests/integration/test_spanner_graph_retriever.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
From 441fa517c5b774e2f9826ba1fa4f28fe25078d3f Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Wed, 8 Jan 2025 21:32:08 +0000 Subject: [PATCH 32/40] Address review comments --- docs/graph_qa_chain.ipynb | 2 +- .../graph_retriever.py | 108 +++++++++--------- 2 files changed, 57 insertions(+), 53 deletions(-) diff --git a/docs/graph_qa_chain.ipynb b/docs/graph_qa_chain.ipynb index 4d47f62..25c18a0 100644 --- a/docs/graph_qa_chain.ipynb +++ b/docs/graph_qa_chain.ipynb @@ -150,7 +150,7 @@ "source": [ "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n", "\n", - "PROJECT_ID = \"\" # @param {type:\"string\"}\n", + "PROJECT_ID = \"my-project-id\" # @param {type:\"string\"}\n", "\n", "# Set the project id\n", "!gcloud config set project {PROJECT_ID}\n", diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index d748de6..43a884f 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -51,55 +51,8 @@ def convert_to_doc(data: dict[str, Any]) -> Document: return Document(page_content=content, metadata={}) -def get_distance_function(distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: - """Gets the vector distance function.""" - if distance_strategy == DistanceStrategy.COSINE: - return "COSINE_DISTANCE" - - return "EUCLIDEAN_DISTANCE" - - -def get_graph_name_from_schema(schema: str): - return json.loads(schema)["Name of graph"] - - -def duplicate_braces_in_string(text): - """Replaces single curly braces with double curly braces in a string. - - Args: - text: The input string. - - Returns: - The modified string with double curly braces. - """ - text = text.replace("{", "{{") - text = text.replace("}", "}}") - return text - - -def clean_element(element, embedding_column): - """Removes specified keys and embedding from properties in graph element. 
- - Args: - element: A dictionary representing element - - Returns: - A cleaned dictionary with the specified keys removed. - """ - - keys_to_remove = [ - "source_node_identifier", - "destination_node_identifier", - "identifier", - ] - for key in keys_to_remove: - if key in element: - del element[key] - - if "properties" in element and embedding_column in element["properties"]: - del element["properties"][embedding_column] - - return element +def get_graph_name_from_schema(schema: str) -> str: + return "`" + json.loads(schema)["Name of graph"] + "`" class SpannerGraphGQLRetriever(BaseRetriever): @@ -185,11 +138,30 @@ def from_params( **kwargs, ) + @staticmethod + def _duplicate_braces_in_string(text: str) -> str: + """Replaces single curly braces with double curly braces in a string. + + Args: + text: The input string. + + Returns: + The modified string with double curly braces. + """ + text = text.replace("{", "{{") + text = text.replace("}", "}}") + return text + def add_example(self, question: str, gql: str): if self.selector is None: raise ValueError("`selector` cannot be None") self.selector.add_example( - {"input": question, "query": duplicate_braces_in_string(gql)} + { + "input": question, + "query": SpannerGraphSemanticGQLRetriever._duplicate_braces_in_string( + gql + ), + } ) def _get_relevant_documents( @@ -288,6 +260,34 @@ def __init__(self, **kwargs: Any) -> None: "One and only one of `return_properties` or `expand_by_hops` must be provided." ) + @staticmethod + def _clean_element(element: dict[str, Any], embedding_column: str) -> None: + """Removes specified keys and embedding from properties in graph element. 
+ + Args: + element: A dictionary representing element + """ + + keys_to_remove = [ + "source_node_identifier", + "destination_node_identifier", + "identifier", + ] + for key in keys_to_remove: + if key in element: + del element[key] + + if "properties" in element and embedding_column in element["properties"]: + del element["properties"][embedding_column] + + @staticmethod + def _get_distance_function(distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: + """Gets the vector distance function.""" + if distance_strategy == DistanceStrategy.COSINE: + return "COSINE_DISTANCE" + + return "EUCLIDEAN_DISTANCE" + def _get_relevant_documents( self, question: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: @@ -302,7 +302,9 @@ def _get_relevant_documents( raise ValueError("`embedding_service` cannot be None") query_embeddings = self.embedding_service.embed_query(question) - distance_fn = get_distance_function(self.query_parameters.distance_strategy) + distance_fn = SpannerGraphNodeVectorRetriever._get_distance_function( + self.query_parameters.distance_strategy + ) VECTOR_QUERY = """ GRAPH {graph_name} @@ -359,7 +361,9 @@ def _get_relevant_documents( for response in responses: elements = json.loads((response["path"]).serialize()) for element in elements: - clean_element(element, self.embeddings_column) + SpannerGraphNodeVectorRetriever._clean_element( + element, self.embeddings_column + ) response["path"] = elements content = dumps(response["path"]) documents.append(Document(page_content=content, metadata={})) From 91421faa0e242e90f41d497889f8572c41784147 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Sat, 18 Jan 2025 00:01:36 +0000 Subject: [PATCH 33/40] Rename the retrievers. Merge the semantic retriever with the gql retriever. 
--- src/langchain_google_spanner/__init__.py | 10 +- .../graph_retriever.py | 109 ++++++------------ .../test_spanner_graph_retriever.py | 15 ++- 3 files changed, 45 insertions(+), 89 deletions(-) diff --git a/src/langchain_google_spanner/__init__.py b/src/langchain_google_spanner/__init__.py index 84383ad..2b8b9c9 100644 --- a/src/langchain_google_spanner/__init__.py +++ b/src/langchain_google_spanner/__init__.py @@ -15,9 +15,8 @@ from langchain_google_spanner.chat_message_history import SpannerChatMessageHistory from langchain_google_spanner.graph_qa import SpannerGraphQAChain from langchain_google_spanner.graph_retriever import ( - SpannerGraphGQLRetriever, - SpannerGraphNodeVectorRetriever, - SpannerGraphSemanticGQLRetriever, + SpannerGraphTextToGQLRetriever, + SpannerGraphVectorContextRetriever, ) from langchain_google_spanner.graph_store import SpannerGraphStore from langchain_google_spanner.vector_store import ( @@ -43,7 +42,6 @@ "SecondaryIndex", "QueryParameters", "DistanceStrategy", - "SpannerGraphGQLRetriever", - "SpannerGraphNodeVectorRetriever", - "SpannerGraphSemanticGQLRetriever", + "SpannerGraphTextToGQLRetriever", + "SpannerGraphVectorContextRetriever", ] diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index 43a884f..e98ce25 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -55,61 +55,12 @@ def get_graph_name_from_schema(schema: str) -> str: return "`" + json.loads(schema)["Name of graph"] + "`" -class SpannerGraphGQLRetriever(BaseRetriever): +class SpannerGraphTextToGQLRetriever(BaseRetriever): """A Retriever that translates natural language queries to GQL and queries SpannerGraphStore using the GQL. Returns the documents retrieved as result. 
- """ - - graph_store: SpannerGraphStore = Field(exclude=True) - gql_chain: RunnableSequence - k: int = 10 - """Number of top results to return""" - - @classmethod - def from_params( - cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any - ) -> "SpannerGraphGQLRetriever": - if llm is None: - raise ValueError("`llm` cannot be none") - gql_chain: RunnableSequence = RunnableSequence( - GQL_GENERATION_PROMPT | llm | StrOutputParser() - ) - return cls(gql_chain=gql_chain, **kwargs) - - def _get_relevant_documents( - self, question: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: - """Translate the natural language query to GQL, execute it, - and return the results as Documents. - """ - - # 1. Generate gql query from natural language query using LLM - gql_query = extract_gql( - self.gql_chain.invoke( - { - "question": question, - "schema": self.graph_store.get_schema, - } - ) - ) - print(gql_query) - - # 2. Execute the gql query against spanner graph - responses = self.graph_store.query(gql_query)[: self.k] - - # 3. Transform the results into a list of Documents - documents = [] - for response in responses: - documents.append(convert_to_doc(response)) - return documents - - -class SpannerGraphSemanticGQLRetriever(BaseRetriever): - """A Retriever that translates natural language queries to GQL and - and queries SpannerGraphStore to retrieve documents. It uses a semantic - similarity model to compare the input question to a set of examples to - generate the GQL query. + If examples are provided, it uses a semantic similarity model to compare the + input question to a set of examples to generate the GQL query. 
""" graph_store: SpannerGraphStore = Field(exclude=True) @@ -124,14 +75,16 @@ def from_params( llm: Optional[BaseLanguageModel] = None, embedding_service: Optional[Embeddings] = None, **kwargs: Any, - ) -> "SpannerGraphSemanticGQLRetriever": + ) -> "SpannerGraphTextToGQLRetriever": if llm is None: raise ValueError("`llm` cannot be none") - if embedding_service is None: - raise ValueError("`embedding_service` cannot be none") - selector = SemanticSimilarityExampleSelector.from_examples( - [], embedding_service, InMemoryVectorStore, k=2 - ) + # if embedding_service is None: + # raise ValueError("`embedding_service` cannot be none") + selector = None + if embedding_service is not None: + selector = SemanticSimilarityExampleSelector.from_examples( + [], embedding_service, InMemoryVectorStore, k=2 + ) return cls( llm=llm, selector=selector, @@ -158,7 +111,7 @@ def add_example(self, question: str, gql: str): self.selector.add_example( { "input": question, - "query": SpannerGraphSemanticGQLRetriever._duplicate_braces_in_string( + "query": SpannerGraphTextToGQLRetriever._duplicate_braces_in_string( gql ), } @@ -174,20 +127,26 @@ def _get_relevant_documents( if self.llm is None: raise ValueError("`llm` cannot be None") - if self.selector is None: - raise ValueError("`selector` cannot be None") + + # if self.selector is None: + # raise ValueError("`selector` cannot be None") # Define the prompt template - prompt = FewShotPromptTemplate( - example_selector=self.selector, - example_prompt=PromptTemplate.from_template( - "Question: {input}\nGQL Query: {query}" - ), - prefix=""" - Create an ISO GQL query for the question using the schema.""", - suffix=DEFAULT_GQL_TEMPLATE_PART1, - input_variables=["question", "schema"], - ) + prompt = None + if self.selector is None: + prompt = GQL_GENERATION_PROMPT + else: + # Define the prompt template + prompt = FewShotPromptTemplate( + example_selector=self.selector, + example_prompt=PromptTemplate.from_template( + "Question: {input}\nGQL 
Query: {query}" + ), + prefix=""" + Create an ISO GQL query for the question using the schema.""", + suffix=DEFAULT_GQL_TEMPLATE_PART1, + input_variables=["question", "schema"], + ) # Initialize the chain gql_chain = prompt | self.llm | StrOutputParser() @@ -212,7 +171,7 @@ def _get_relevant_documents( return documents -class SpannerGraphNodeVectorRetriever(BaseRetriever): +class SpannerGraphVectorContextRetriever(BaseRetriever): """Retriever that does a vector search on nodes in a SpannerGraphStore. If expand_by_hops is provided , the nodes (and edges) at a distance upto the expand_by hops will also be returned. @@ -237,7 +196,7 @@ class SpannerGraphNodeVectorRetriever(BaseRetriever): @classmethod def from_params( cls, embedding_service: Optional[Embeddings] = None, **kwargs: Any - ) -> "SpannerGraphNodeVectorRetriever": + ) -> "SpannerGraphVectorContextRetriever": if embedding_service is None: raise ValueError("`embedding_service` cannot be None") return cls( @@ -302,7 +261,7 @@ def _get_relevant_documents( raise ValueError("`embedding_service` cannot be None") query_embeddings = self.embedding_service.embed_query(question) - distance_fn = SpannerGraphNodeVectorRetriever._get_distance_function( + distance_fn = SpannerGraphVectorContextRetriever._get_distance_function( self.query_parameters.distance_strategy ) @@ -361,7 +320,7 @@ def _get_relevant_documents( for response in responses: elements = json.loads((response["path"]).serialize()) for element in elements: - SpannerGraphNodeVectorRetriever._clean_element( + SpannerGraphVectorContextRetriever._clean_element( element, self.embeddings_column ) response["path"] = elements diff --git a/tests/integration/test_spanner_graph_retriever.py b/tests/integration/test_spanner_graph_retriever.py index edcd800..1a34702 100644 --- a/tests/integration/test_spanner_graph_retriever.py +++ b/tests/integration/test_spanner_graph_retriever.py @@ -23,9 +23,8 @@ from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings 
from langchain_google_spanner.graph_retriever import ( - SpannerGraphGQLRetriever, - SpannerGraphNodeVectorRetriever, - SpannerGraphSemanticGQLRetriever, + SpannerGraphTextToGQLRetriever, + SpannerGraphVectorContextRetriever, ) from langchain_google_spanner.graph_store import SpannerGraphStore @@ -167,7 +166,7 @@ def setup_db_load_data(self): def test_spanner_graph_gql_retriever(self, setup_db_load_data): graph, suffix = setup_db_load_data - retriever = SpannerGraphGQLRetriever.from_params( + retriever = SpannerGraphTextToGQLRetriever.from_params( graph_store=graph, llm=get_llm(), ) @@ -179,7 +178,7 @@ def test_spanner_graph_gql_retriever(self, setup_db_load_data): def test_spanner_graph_semantic_gql_retriever(self, setup_db_load_data): graph, suffix = setup_db_load_data suffix = "_" + suffix - retriever = SpannerGraphSemanticGQLRetriever.from_params( + retriever = SpannerGraphTextToGQLRetriever.from_params( graph_store=graph, llm=get_llm(), embedding_service=get_embedding(), @@ -213,7 +212,7 @@ def test_spanner_graph_vector_node_retriever_error(self, setup_db_load_data): with pytest.raises(ValueError): graph, suffix = setup_db_load_data suffix = "_" + suffix - SpannerGraphNodeVectorRetriever.from_params( + SpannerGraphVectorContextRetriever.from_params( graph_store=graph, embedding_service=get_embedding(), label_expr="Person{}".format(suffix), @@ -224,7 +223,7 @@ def test_spanner_graph_vector_node_retriever_error(self, setup_db_load_data): def test_spanner_graph_vector_node_retriever(self, setup_db_load_data): graph, suffix = setup_db_load_data suffix = "_" + suffix - retriever = SpannerGraphNodeVectorRetriever.from_params( + retriever = SpannerGraphVectorContextRetriever.from_params( graph_store=graph, embedding_service=get_embedding(), label_expr="Person{}".format(suffix), @@ -240,7 +239,7 @@ def test_spanner_graph_vector_node_retriever(self, setup_db_load_data): def test_spanner_graph_vector_node_retriever_2(self, setup_db_load_data): graph, suffix = 
setup_db_load_data suffix = "_" + suffix - retriever = SpannerGraphNodeVectorRetriever.from_params( + retriever = SpannerGraphVectorContextRetriever.from_params( graph_store=graph, embedding_service=get_embedding(), label_expr="Person{}".format(suffix), From 3d9b6f61386aa1b06570ee859d4d324baa919723 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Sat, 18 Jan 2025 01:11:14 +0000 Subject: [PATCH 34/40] Fixed lint errors --- src/langchain_google_spanner/graph_retriever.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index e98ce25..e116418 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -131,13 +131,15 @@ def _get_relevant_documents( # if self.selector is None: # raise ValueError("`selector` cannot be None") - # Define the prompt template - prompt = None + gql_chain: RunnableSequence if self.selector is None: - prompt = GQL_GENERATION_PROMPT + generic_prompt = GQL_GENERATION_PROMPT + gql_chain = RunnableSequence( + generic_prompt | self.llm | StrOutputParser() + ) else: # Define the prompt template - prompt = FewShotPromptTemplate( + few_shot_prompt = FewShotPromptTemplate( example_selector=self.selector, example_prompt=PromptTemplate.from_template( "Question: {input}\nGQL Query: {query}" @@ -147,9 +149,10 @@ def _get_relevant_documents( suffix=DEFAULT_GQL_TEMPLATE_PART1, input_variables=["question", "schema"], ) + gql_chain = RunnableSequence ( + few_shot_prompt | self.llm | StrOutputParser() + ) - # Initialize the chain - gql_chain = prompt | self.llm | StrOutputParser() # 1. 
Generate gql query from natural language query using LLM gql_query = extract_gql( gql_chain.invoke( From 2a3b63eff19773956ab0e12f7a94bccc98fc1a90 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Wed, 22 Jan 2025 00:39:24 +0000 Subject: [PATCH 35/40] Change vertex ai versionto latest --- pyproject.toml | 2 +- src/langchain_google_spanner/graph_retriever.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b5c5131..597a6f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ test = [ "pytest==8.3.3", "pytest-asyncio==0.24.0", "pytest-cov==5.0.0", - "langchain_google_vertexai==1.0.10" + "langchain_google_vertexai==2.0.8" ] [build-system] diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index e116418..0c7e55a 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -78,8 +78,6 @@ def from_params( ) -> "SpannerGraphTextToGQLRetriever": if llm is None: raise ValueError("`llm` cannot be none") - # if embedding_service is None: - # raise ValueError("`embedding_service` cannot be none") selector = None if embedding_service is not None: selector = SemanticSimilarityExampleSelector.from_examples( @@ -128,9 +126,6 @@ def _get_relevant_documents( if self.llm is None: raise ValueError("`llm` cannot be None") - # if self.selector is None: - # raise ValueError("`selector` cannot be None") - gql_chain: RunnableSequence if self.selector is None: generic_prompt = GQL_GENERATION_PROMPT From 577f511e8e384ed59353abfc96ab1b2fdd6ff4ee Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Wed, 22 Jan 2025 01:14:49 +0000 Subject: [PATCH 36/40] Fix lint errors --- src/langchain_google_spanner/graph_retriever.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index 
0c7e55a..7f0bb0f 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -129,9 +129,7 @@ def _get_relevant_documents( gql_chain: RunnableSequence if self.selector is None: generic_prompt = GQL_GENERATION_PROMPT - gql_chain = RunnableSequence( - generic_prompt | self.llm | StrOutputParser() - ) + gql_chain = RunnableSequence(generic_prompt | self.llm | StrOutputParser()) else: # Define the prompt template few_shot_prompt = FewShotPromptTemplate( @@ -144,9 +142,7 @@ def _get_relevant_documents( suffix=DEFAULT_GQL_TEMPLATE_PART1, input_variables=["question", "schema"], ) - gql_chain = RunnableSequence ( - few_shot_prompt | self.llm | StrOutputParser() - ) + gql_chain = RunnableSequence(few_shot_prompt | self.llm | StrOutputParser()) # 1. Generate gql query from natural language query using LLM gql_query = extract_gql( From 82d427af278ebdb6ee93b023c1d505e935bb49c7 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Fri, 24 Jan 2025 20:31:09 +0000 Subject: [PATCH 37/40] Add documentation. Fixes the case where expands_by_hops is 0 --- README.rst | 45 +++++++++++++++++++ .../graph_retriever.py | 8 +++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index cb047dd..7051354 100644 --- a/README.rst +++ b/README.rst @@ -179,6 +179,51 @@ See the full `Spanner Graph QA Chain`_ tutorial. .. _`Spanner Graph QA Chain`: https://github.com/googleapis/langchain-google-spanner-python/blob/main/docs/graph_qa_chain.ipynb +Spanner Graph Retrievers Usage +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Use ``SpannerGraphTextToGQLRetriever`` to translate natural language question to GQL and query SpannerGraphStore. + +.. 
code:: python + from langchain_google_spanner import SpannerGraphStore, SpannerGraphTextToGQLRetriever + from langchain_google_vertexai import ChatVertexAI + + + graph = SpannerGraphStore( + instance_id="my-instance", + database_id="my-database", + graph_name="my_graph", + ) + llm = ChatVertexAI() + retriever = SpannerGraphTextToGQLRetriever.from_params( + graph_store=graph, + llm=llm + ) + retriever.invoke("Where does Elias Thorne's sibling live?") + +Use ``SpannerGraphVectorContextRetriever`` to perform vector search on nodes in a SpannerGraphStore. If expand_by_hops is provided, the nodes and edges at a distance upto the expand_by_hops from the nodes found in the vector search will also be returned. + +.. code:: python + + from langchain_google_spanner import SpannerGraphStore, SpannerGraphQAChain + from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings + + + graph = SpannerGraphStore( + instance_id="my-instance", + database_id="my-database", + graph_name="my_graph", + ) + embedding_service = VertexAIEmbeddings(model_name="text-embedding-004") + retriever = SpannerGraphVectorContextRetriever.from_params( + graph_store=graph, + embedding_service=embedding_service, + label_expr="Person", + embeddings_column="embeddings", + top_k=1, + expand_by_hops=1, + ) + retriever.invoke("Who lives in desert?") Contributions ~~~~~~~~~~~~~ diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index 7f0bb0f..d2ec5b0 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -168,7 +168,7 @@ def _get_relevant_documents( class SpannerGraphVectorContextRetriever(BaseRetriever): """Retriever that does a vector search on nodes in a SpannerGraphStore. If expand_by_hops is provided , the nodes (and edges) at a distance upto - the expand_by hops will also be returned. + the expand_by_hops will also be returned. 
""" graph_store: SpannerGraphStore = Field(exclude=True) @@ -277,7 +277,11 @@ def _get_relevant_documents( top_k=self.top_k, ) - if self.expand_by_hops >= 0: + if self.expand_by_hops == 0: + gql_query += """ + RETURN SAFE_TO_JSON(node) as path + """ + elif self.expand_by_hops > 0: gql_query += """ RETURN node NEXT From 685d1f1403835ba85f2a3729c23e439c4c8c301d Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Fri, 24 Jan 2025 20:56:59 +0000 Subject: [PATCH 38/40] Add unit test for expand_by_hops=0 --- README.rst | 5 +++-- .../graph_retriever.py | 2 +- .../test_spanner_graph_retriever.py | 18 ++++++++++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 7051354..4ee1912 100644 --- a/README.rst +++ b/README.rst @@ -201,11 +201,11 @@ Use ``SpannerGraphTextToGQLRetriever`` to translate natural language question to ) retriever.invoke("Where does Elias Thorne's sibling live?") -Use ``SpannerGraphVectorContextRetriever`` to perform vector search on nodes in a SpannerGraphStore. If expand_by_hops is provided, the nodes and edges at a distance upto the expand_by_hops from the nodes found in the vector search will also be returned. +Use ``SpannerGraphVectorContextRetriever`` to perform vector search on embeddings that are stored in the nodes in a SpannerGraphStore. If expand_by_hops is provided, the nodes and edges at a distance upto the expand_by_hops from the nodes found in the vector search will also be returned. .. 
code:: python - from langchain_google_spanner import SpannerGraphStore, SpannerGraphQAChain + from langchain_google_spanner import SpannerGraphStore, SpannerGraphVectorContextRetriever from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings @@ -225,6 +225,7 @@ Use ``SpannerGraphVectorContextRetriever`` to perform vector search on nodes in ) retriever.invoke("Who lives in desert?") + Contributions ~~~~~~~~~~~~~ diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index d2ec5b0..8f0f50a 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -314,7 +314,7 @@ def _get_relevant_documents( # 3. Transform the results into a list of Documents documents = [] - if self.expand_by_hops >= 0: + if self.expand_by_hops > 0: for response in responses: elements = json.loads((response["path"]).serialize()) for element in elements: diff --git a/tests/integration/test_spanner_graph_retriever.py b/tests/integration/test_spanner_graph_retriever.py index 1a34702..b3fa0f6 100644 --- a/tests/integration/test_spanner_graph_retriever.py +++ b/tests/integration/test_spanner_graph_retriever.py @@ -253,3 +253,21 @@ def test_spanner_graph_vector_node_retriever_2(self, setup_db_load_data): ) assert len(response) == 4 assert "Elias Thorne" in response[0].page_content + + def test_spanner_graph_vector_node_retriever_0_hops(self, setup_db_load_data): + graph, suffix = setup_db_load_data + suffix = "_" + suffix + retriever = SpannerGraphVectorContextRetriever.from_params( + graph_store=graph, + embedding_service=get_embedding(), + label_expr="Person{}".format(suffix), + expand_by_hops=0, + embeddings_column="desc_embedding", + top_k=1, + k=10, + ) + response = retriever.invoke( + "What do you know about the person who lives in desert?" 
+ ) + assert len(response) == 1 + assert "Elias Thorne" in response[0].page_content From 1aeb21c3716d1e63a71c81c1fa38a4ad6b6ad685 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Fri, 24 Jan 2025 21:00:42 +0000 Subject: [PATCH 39/40] Fix formatting for documentation --- README.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/README.rst b/README.rst index 4ee1912..48fafdd 100644 --- a/README.rst +++ b/README.rst @@ -185,6 +185,7 @@ Spanner Graph Retrievers Usage Use ``SpannerGraphTextToGQLRetriever`` to translate natural language question to GQL and query SpannerGraphStore. .. code:: python + from langchain_google_spanner import SpannerGraphStore, SpannerGraphTextToGQLRetriever from langchain_google_vertexai import ChatVertexAI From 905cf65ff0ccac97fb59208fac6af413f1a67904 Mon Sep 17 00:00:00 2001 From: Amarnath Mullick Date: Sat, 25 Jan 2025 01:31:54 +0000 Subject: [PATCH 40/40] Addressed review comments --- .../graph_retriever.py | 26 +++++++------------ 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index 8f0f50a..d746eb0 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -89,8 +89,7 @@ def from_params( **kwargs, ) - @staticmethod - def _duplicate_braces_in_string(text: str) -> str: + def __duplicate_braces_in_string(self, text: str) -> str: """Replaces single curly braces with double curly braces in a string. Args: @@ -109,9 +108,7 @@ def add_example(self, question: str, gql: str): self.selector.add_example( { "input": question, - "query": SpannerGraphTextToGQLRetriever._duplicate_braces_in_string( - gql - ), + "query": self.__duplicate_braces_in_string(gql), } ) @@ -153,7 +150,6 @@ def _get_relevant_documents( } ) ) - print(gql_query) # 2. 
Execute the gql query against spanner graph responses = self.graph_store.query(gql_query)[: self.k] @@ -189,7 +185,7 @@ class SpannerGraphVectorContextRetriever(BaseRetriever): @classmethod def from_params( - cls, embedding_service: Optional[Embeddings] = None, **kwargs: Any + cls, embedding_service: Embeddings, **kwargs: Any ) -> "SpannerGraphVectorContextRetriever": if embedding_service is None: raise ValueError("`embedding_service` cannot be None") @@ -213,8 +209,7 @@ def __init__(self, **kwargs: Any) -> None: "One and only one of `return_properties` or `expand_by_hops` must be provided." ) - @staticmethod - def _clean_element(element: dict[str, Any], embedding_column: str) -> None: + def __clean_element(self, element: dict[str, Any], embedding_column: str) -> None: """Removes specified keys and embedding from properties in graph element. Args: @@ -233,8 +228,9 @@ def _clean_element(element: dict[str, Any], embedding_column: str) -> None: if "properties" in element and embedding_column in element["properties"]: del element["properties"][embedding_column] - @staticmethod - def _get_distance_function(distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: + def __get_distance_function( + self, distance_strategy=DistanceStrategy.EUCLIDEIAN + ) -> str: """Gets the vector distance function.""" if distance_strategy == DistanceStrategy.COSINE: return "COSINE_DISTANCE" @@ -255,7 +251,7 @@ def _get_relevant_documents( raise ValueError("`embedding_service` cannot be None") query_embeddings = self.embedding_service.embed_query(question) - distance_fn = SpannerGraphVectorContextRetriever._get_distance_function( + distance_fn = self.__get_distance_function( self.query_parameters.distance_strategy ) @@ -307,8 +303,6 @@ def _get_relevant_documents( "Either `return_properties` or `expand_by_hops` must be provided." ) - print(gql_query) - # 2. 
Execute the gql query against spanner graph responses = self.graph_store.query(gql_query)[: self.k] @@ -318,9 +312,7 @@ def _get_relevant_documents( for response in responses: elements = json.loads((response["path"]).serialize()) for element in elements: - SpannerGraphVectorContextRetriever._clean_element( - element, self.embeddings_column - ) + self.__clean_element(element, self.embeddings_column) response["path"] = elements content = dumps(response["path"]) documents.append(Document(page_content=content, metadata={}))