Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(graph): Add Custom Retrievers for Spanner Graph RAG. #122

Merged
merged 44 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
80b6f19
Add Spanner Graph QA Chain
Nov 26, 2024
d8e6640
Formatted notebook. Added copyright message to prompts file.
Nov 27, 2024
7a27522
Add missing imports for random graph name
Nov 27, 2024
e0bbdf1
Make input table name randomized in integration tests to avoid name c…
Nov 28, 2024
7f578bc
Provide timeout to graph cleanup
Dec 2, 2024
3650d89
Make default timeout of 300 secs for ddl application
Dec 3, 2024
63b3508
Increase timeout of integration test
Dec 3, 2024
b9f718c
Change integration test timeout
Dec 3, 2024
95768fc
Minor formatting fixes
Dec 3, 2024
638cdeb
Make the ddl operations test fixture scoped for the module
Dec 3, 2024
88aa4c1
Addressed review comments
Dec 4, 2024
a211728
Addressed a few other review comments.
Dec 4, 2024
80338fc
Remove unused function
Dec 4, 2024
c8799ea
fix type check errors
Dec 4, 2024
0d358aa
Addressed review comments
Dec 4, 2024
340eadc
Addressed review comments
Dec 5, 2024
1439f04
Clear default project id from notebook
Dec 5, 2024
449ec35
Merge branch 'main' into graphqachain
amullick-git Dec 5, 2024
06d0489
Add import statement for SpanerGraphQAChain to notebook
Dec 5, 2024
81a35e3
Merge branch 'graphqachain' of https://github.com/amullick-git/langch…
Dec 5, 2024
a086ff6
Add retrievers for Spanner Graph RAG
Dec 12, 2024
1ae9fec
Add licence headers
Dec 12, 2024
704ed32
Fix DATABASE name key
Dec 12, 2024
11f674e
Fix lint error on import ordering
Dec 12, 2024
1daf9e6
Fix lint errors
Dec 13, 2024
c46ee65
Few minor changes to the SpannerGraphNodeVectorRetriever
Dec 16, 2024
67d9c5e
Fix lint error
Dec 16, 2024
b3e4e3c
Add an option to expand context graph by hops
Dec 18, 2024
f8db780
Fix lint error
Dec 18, 2024
c0fdd69
Addressed review comments
Dec 20, 2024
17dba7f
Remove expansion query options
Dec 20, 2024
ebe0b86
Add backticks to property names
Jan 4, 2025
eb30c87
Change copyright year
Jan 6, 2025
441fa51
Address review comments
Jan 8, 2025
91421fa
Rename the retrievers. Merge the semantic retriever with the gql retr…
Jan 18, 2025
3d9b6f6
Fixed lint errors
Jan 18, 2025
dc4f993
Merge branch 'main' into graphqachain
amullick-git Jan 18, 2025
2a3b63e
Change vertex ai versionto latest
Jan 22, 2025
d6e3173
Merge branch 'graphqachain' of https://github.com/amullick-git/langch…
Jan 22, 2025
577f511
Fix lint errors
Jan 22, 2025
82d427a
Add documentation. Fixes the case where expands_by_hops is 0
Jan 24, 2025
685d1f1
Add unit test for expand_by_hops=0
Jan 24, 2025
1aeb21c
Fix formatting for documentation
Jan 24, 2025
905cf65
Addressed review comments
Jan 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,53 @@ See the full `Spanner Graph QA Chain`_ tutorial.

.. _`Spanner Graph QA Chain`: https://github.com/googleapis/langchain-google-spanner-python/blob/main/docs/graph_qa_chain.ipynb

Spanner Graph Retrievers Usage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Use ``SpannerGraphTextToGQLRetriever`` to translate natural language question to GQL and query SpannerGraphStore.

.. code:: python

from langchain_google_spanner import SpannerGraphStore, SpannerGraphTextToGQLRetriever
from langchain_google_vertexai import ChatVertexAI


graph = SpannerGraphStore(
instance_id="my-instance",
database_id="my-database",
graph_name="my_graph",
)
llm = ChatVertexAI()
retriever = SpannerGraphTextToGQLRetriever.from_params(
graph_store=graph,
llm=llm
)
retriever.invoke("Where does Elias Thorne's sibling live?")

Use ``SpannerGraphVectorContextRetriever`` to perform vector search on embeddings that are stored in the nodes in a SpannerGraphStore. If expand_by_hops is provided, the nodes and edges at a distance upto the expand_by_hops from the nodes found in the vector search will also be returned.

.. code:: python

from langchain_google_spanner import SpannerGraphStore, SpannerGraphVectorContextRetriever
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings


graph = SpannerGraphStore(
instance_id="my-instance",
database_id="my-database",
graph_name="my_graph",
)
embedding_service = VertexAIEmbeddings(model_name="text-embedding-004")
retriever = SpannerGraphVectorContextRetriever.from_params(
graph_store=graph,
embedding_service=embedding_service,
label_expr="Person",
embeddings_column="embeddings",
top_k=1,
expand_by_hops=1,
)
retriever.invoke("Who lives in desert?")


Contributions
~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion docs/graph_qa_chain.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@
"source": [
"# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n",
"\n",
"PROJECT_ID = \"\" # @param {type:\"string\"}\n",
"PROJECT_ID = \"my-project-id\" # @param {type:\"string\"}\n",
"\n",
"# Set the project id\n",
"!gcloud config set project {PROJECT_ID}\n",
Expand Down
6 changes: 6 additions & 0 deletions src/langchain_google_spanner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

from langchain_google_spanner.chat_message_history import SpannerChatMessageHistory
from langchain_google_spanner.graph_qa import SpannerGraphQAChain
from langchain_google_spanner.graph_retriever import (
SpannerGraphTextToGQLRetriever,
SpannerGraphVectorContextRetriever,
)
from langchain_google_spanner.graph_store import SpannerGraphStore
from langchain_google_spanner.vector_store import (
DistanceStrategy,
Expand All @@ -38,4 +42,6 @@
"SecondaryIndex",
"QueryParameters",
"DistanceStrategy",
"SpannerGraphTextToGQLRetriever",
"SpannerGraphVectorContextRetriever",
]
46 changes: 1 addition & 45 deletions src/langchain_google_spanner/graph_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
Expand All @@ -28,6 +27,7 @@

from langchain_google_spanner.graph_store import SpannerGraphStore

from .graph_utils import extract_gql, fix_gql_syntax
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will have to note this as a breaking change. I can add that manually to the change log when releasing

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack. Thank you for point this out.

from .prompts import (
DEFAULT_GQL_FIX_TEMPLATE,
DEFAULT_GQL_TEMPLATE,
Expand Down Expand Up @@ -71,50 +71,6 @@ class VerifyGqlOutput(BaseModel):
INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def fix_gql_syntax(query: str) -> str:
"""Fixes the syntax of a GQL query.
Example 1:
Input:
MATCH (p:paper {id: 0})-[c:cites*8]->(p2:paper)
Output:
MATCH (p:paper {id: 0})-[c:cites]->{8}(p2:paper)
Example 2:
Input:
MATCH (p:paper {id: 0})-[c:cites*1..8]->(p2:paper)
Output:
MATCH (p:paper {id: 0})-[c:cites]->{1:8}(p2:paper)

Args:
query: The input GQL query.

Returns:
Possibly modified GQL query.
"""

query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]->", r"-[\1:\2]->{\3,\4}", query)
query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]->", r"-[\1:\2]->{\3}", query)
query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"<-[\1:\2]-{\3,\4}", query)
query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\]-", r"<-[\1:\2]-{\3}", query)
query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"-[\1:\2]-{\3,\4}", query)
query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]-", r"-[\1:\2]-{\3}", query)
return query


def extract_gql(text: str) -> str:
"""Extract GQL query from a text.

Args:
text: Text to extract GQL query from.

Returns:
GQL query extracted from the text.
"""
pattern = r"```(.*?)```"
matches = re.findall(pattern, text, re.DOTALL)
query = matches[0] if matches else text
return fix_gql_syntax(query)


class SpannerGraphQAChain(Chain):
"""Chain for question-answering against a Spanner Graph database by
generating GQL statements from natural language questions.
Expand Down
Loading
Loading