feat(graph): Add Custom Retrievers for Spanner Graph RAG. (#122)
* Add Spanner Graph QA Chain

* Formatted notebook. Added copyright message to prompts file.

* Add missing imports for random graph name

* Make input table name randomized in integration tests to avoid name collisions when tests run in parallel from different Python environments

* Provide timeout to graph cleanup

* Make default timeout of 300 secs for ddl application

* Increase timeout of integration test

* Change integration test timeout

* Minor formatting fixes

* Make the ddl operations test fixture scoped for the module

* Addressed review comments

* Addressed a few other review comments.

* Remove unused function

* fix type check errors

* Addressed review comments

* Addressed review comments

* Clear default project id from notebook

* Add import statement for SpannerGraphQAChain to notebook

* Add retrievers for Spanner Graph RAG

* Add licence headers

* Fix DATABASE name key

* Fix lint error on import ordering

* Fix lint errors

* Few minor changes to the SpannerGraphNodeVectorRetriever

* Fix lint error

* Add an option to expand context graph by hops

* Fix lint error

* Addressed review comments

* Remove expansion query options

* Add backticks to property names

* Change copyright year

* Address review comments

* Rename the retrievers. Merge the semantic retriever with the gql retriever.

* Fixed lint errors

* Change vertex ai version to latest

* Fix lint errors

* Add documentation. Fix the case where expand_by_hops is 0

* Add unit test for expand_by_hops=0

* Fix formatting for documentation

* Addressed review comments

---------

Co-authored-by: Amarnath Mullick <[email protected]>
amullick-git and Amarnath Mullick authored Jan 29, 2025
1 parent fd788d8 commit bf2903a
Showing 7 changed files with 714 additions and 46 deletions.
47 changes: 47 additions & 0 deletions README.rst
@@ -179,6 +179,53 @@ See the full `Spanner Graph QA Chain`_ tutorial.

.. _`Spanner Graph QA Chain`: https://github.com/googleapis/langchain-google-spanner-python/blob/main/docs/graph_qa_chain.ipynb

Spanner Graph Retrievers Usage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Use ``SpannerGraphTextToGQLRetriever`` to translate a natural language question into GQL and run it against a ``SpannerGraphStore``.

.. code:: python

    from langchain_google_spanner import SpannerGraphStore, SpannerGraphTextToGQLRetriever
    from langchain_google_vertexai import ChatVertexAI

    graph = SpannerGraphStore(
        instance_id="my-instance",
        database_id="my-database",
        graph_name="my_graph",
    )
    llm = ChatVertexAI()
    retriever = SpannerGraphTextToGQLRetriever.from_params(
        graph_store=graph,
        llm=llm
    )
    retriever.invoke("Where does Elias Thorne's sibling live?")

Use ``SpannerGraphVectorContextRetriever`` to perform vector search on embeddings stored in the nodes of a ``SpannerGraphStore``. If ``expand_by_hops`` is provided, the nodes and edges within ``expand_by_hops`` hops of the nodes found by the vector search are also returned.

.. code:: python

    from langchain_google_spanner import SpannerGraphStore, SpannerGraphVectorContextRetriever
    from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings

    graph = SpannerGraphStore(
        instance_id="my-instance",
        database_id="my-database",
        graph_name="my_graph",
    )
    embedding_service = VertexAIEmbeddings(model_name="text-embedding-004")
    retriever = SpannerGraphVectorContextRetriever.from_params(
        graph_store=graph,
        embedding_service=embedding_service,
        label_expr="Person",
        embeddings_column="embeddings",
        top_k=1,
        expand_by_hops=1,
    )
    retriever.invoke("Who lives in desert?")

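To make the hop expansion described above concrete, here is a minimal in-memory sketch. It is not the retriever's implementation: the real retriever runs a graph query against Spanner, and this diff does not show how it treats edge direction. The function name ``expand_context``, the ``Edge`` tuple layout, and the choice to follow edges in both directions are all assumptions made for illustration.

```python
from collections import deque
from typing import Dict, List, Set, Tuple

# Hypothetical edge representation: (source node, edge label, target node).
Edge = Tuple[str, str, str]


def expand_context(
    edges: List[Edge], seeds: Set[str], hops: int
) -> Tuple[Set[str], List[Edge]]:
    """Collect nodes and edges within `hops` of the seed nodes (breadth-first).

    Assumption: edges are followed in both directions.
    """
    # Index every edge under both of its endpoints.
    adjacency: Dict[str, List[Edge]] = {}
    for edge in edges:
        src, _, dst = edge
        adjacency.setdefault(src, []).append(edge)
        adjacency.setdefault(dst, []).append(edge)

    visited: Set[str] = set(seeds)
    collected: List[Edge] = []
    frontier = deque((seed, 0) for seed in sorted(seeds))
    while frontier:
        node, depth = frontier.popleft()
        if depth == hops:
            # Nodes at the hop limit are kept but not expanded further.
            continue
        for edge in adjacency.get(node, []):
            if edge not in collected:
                collected.append(edge)
            src, _, dst = edge
            neighbor = dst if src == node else src
            if neighbor not in visited:
                visited.add(neighbor)
                frontier.append((neighbor, depth + 1))
    return visited, collected


# Toy graph; in the real retriever the seeds would come from the vector search.
toy_edges = [
    ("Elias", "SIBLING", "Zeva"),
    ("Zeva", "LIVES_IN", "Desert"),
    ("Desert", "PART_OF", "Region"),
]
nodes_0, edges_0 = expand_context(toy_edges, {"Elias"}, hops=0)
nodes_1, edges_1 = expand_context(toy_edges, {"Elias"}, hops=1)
nodes_2, edges_2 = expand_context(toy_edges, {"Elias"}, hops=2)
```

With ``hops=0`` only the seed nodes come back and no edges are collected, which is the edge case this commit's message says it fixes and covers with a unit test.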
Contributions
~~~~~~~~~~~~~
2 changes: 1 addition & 1 deletion docs/graph_qa_chain.ipynb
@@ -150,7 +150,7 @@
"source": [
"# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n",
"\n",
"PROJECT_ID = \"\" # @param {type:\"string\"}\n",
"PROJECT_ID = \"my-project-id\" # @param {type:\"string\"}\n",
"\n",
"# Set the project id\n",
"!gcloud config set project {PROJECT_ID}\n",
6 changes: 6 additions & 0 deletions src/langchain_google_spanner/__init__.py
@@ -14,6 +14,10 @@

from langchain_google_spanner.chat_message_history import SpannerChatMessageHistory
from langchain_google_spanner.graph_qa import SpannerGraphQAChain
from langchain_google_spanner.graph_retriever import (
    SpannerGraphTextToGQLRetriever,
    SpannerGraphVectorContextRetriever,
)
from langchain_google_spanner.graph_store import SpannerGraphStore
from langchain_google_spanner.vector_store import (
    DistanceStrategy,
@@ -38,4 +42,6 @@
    "SecondaryIndex",
    "QueryParameters",
    "DistanceStrategy",
    "SpannerGraphTextToGQLRetriever",
    "SpannerGraphVectorContextRetriever",
]
46 changes: 1 addition & 45 deletions src/langchain_google_spanner/graph_qa.py
@@ -14,7 +14,6 @@

from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
@@ -28,6 +27,7 @@

from langchain_google_spanner.graph_store import SpannerGraphStore

from .graph_utils import extract_gql, fix_gql_syntax
from .prompts import (
    DEFAULT_GQL_FIX_TEMPLATE,
    DEFAULT_GQL_TEMPLATE,
@@ -71,50 +71,6 @@ class VerifyGqlOutput(BaseModel):
INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def fix_gql_syntax(query: str) -> str:
    """Fixes the syntax of a GQL query.

    Example 1:
        Input:
            MATCH (p:paper {id: 0})-[c:cites*8]->(p2:paper)
        Output:
            MATCH (p:paper {id: 0})-[c:cites]->{8}(p2:paper)
    Example 2:
        Input:
            MATCH (p:paper {id: 0})-[c:cites*1..8]->(p2:paper)
        Output:
            MATCH (p:paper {id: 0})-[c:cites]->{1,8}(p2:paper)

    Args:
        query: The input GQL query.

    Returns:
        Possibly modified GQL query.
    """

    query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]->", r"-[\1:\2]->{\3,\4}", query)
    query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]->", r"-[\1:\2]->{\3}", query)
    query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"<-[\1:\2]-{\3,\4}", query)
    query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\]-", r"<-[\1:\2]-{\3}", query)
    query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"-[\1:\2]-{\3,\4}", query)
    query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]-", r"-[\1:\2]-{\3}", query)
    return query


def extract_gql(text: str) -> str:
    """Extract GQL query from a text.

    Args:
        text: Text to extract GQL query from.

    Returns:
        GQL query extracted from the text.
    """
    pattern = r"```(.*?)```"
    matches = re.findall(pattern, text, re.DOTALL)
    query = matches[0] if matches else text
    return fix_gql_syntax(query)


class SpannerGraphQAChain(Chain):
    """Chain for question-answering against a Spanner Graph database by
    generating GQL statements from natural language questions.
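The two helpers removed from `graph_qa.py` in this hunk are now imported from `.graph_utils`. The standalone sketch below reproduces their regexes from the removed lines to show what the rewrite does to Cypher-style variable-length edges. It assumes the functions were moved unchanged, since the new `graph_utils.py` is not shown in this view. The `llm_reply` string is a made-up example.

```python
import re


def fix_gql_syntax(query: str) -> str:
    """Rewrite `*N` / `*M..N` edge quantifiers into `{N}` / `{M,N}` form."""
    # Right-directed edges: -[e:label*1..8]-> and -[e:label*8]->
    query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]->", r"-[\1:\2]->{\3,\4}", query)
    query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]->", r"-[\1:\2]->{\3}", query)
    # Left-directed edges: <-[e:label*1..8]- and <-[e:label*8]-
    query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"<-[\1:\2]-{\3,\4}", query)
    query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\]-", r"<-[\1:\2]-{\3}", query)
    # Undirected edges: -[e:label*1..8]- and -[e:label*8]-
    query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"-[\1:\2]-{\3,\4}", query)
    query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]-", r"-[\1:\2]-{\3}", query)
    return query


def extract_gql(text: str) -> str:
    """Take the first triple-backtick block from `text` (or all of it) and fix it."""
    matches = re.findall(r"```(.*?)```", text, re.DOTALL)
    query = matches[0] if matches else text
    return fix_gql_syntax(query)


fixed_exact = fix_gql_syntax("MATCH (p:paper {id: 0})-[c:cites*8]->(p2:paper)")
fixed_range = fix_gql_syntax("MATCH (p:paper {id: 0})-[c:cites*1..8]->(p2:paper)")
fixed_left = fix_gql_syntax("MATCH (a)<-[e:cites*1..3]-(b)")

# Hypothetical LLM reply that wraps the query in a fenced block.
llm_reply = "Here is the query:\n```\nMATCH (n)-[e:knows*2]->(m) RETURN m\n```"
extracted = extract_gql(llm_reply).strip()
```

The range patterns run before the single-count patterns, and the directed patterns run before the undirected ones. The rewritten text no longer contains a `*`, so a later pattern cannot match an edge that an earlier one already rewrote.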