Skip to content

Commit

Permalink
add user agent
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 committed Feb 21, 2024
1 parent 5cea61c commit 8315d9d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 24 deletions.
31 changes: 23 additions & 8 deletions src/langchain_google_spanner/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@
import datetime
import json
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Union
from typing import Any, Dict, Iterator, List, Optional, Union

from google.cloud.spanner import Client, KeySet # type: ignore
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect # type: ignore
from google.cloud.spanner_v1.data_types import JsonObject # type: ignore
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document

from .version import __version__

USER_AGENT_LOADER = "langchain-google-spanner-python:document_loader" + __version__

OPERATION_TIMEOUT_SECONDS = 240
MUTATION_BATCH_SIZE = 1000

Expand All @@ -37,6 +41,17 @@ class Column:
nullable: bool = True


def client_with_user_agent(client: Optional[Client], user_agent: str) -> Client:
if not client:
client = Client()
client_agent = client._client_info.user_agent
if not client_agent:
client._client_info.user_agent = user_agent
elif user_agent not in client_agent:
client._client_info.user_agent = " ".join([client_agent, user_agent])
return client


def _load_row_to_doc(
format: str,
content_columns: List[str],
Expand Down Expand Up @@ -123,21 +138,20 @@ def __init__(
instance_id: str,
database_id: str,
query: str,
client: Client = Client(),
content_columns: List[str] = [],
metadata_columns: List[str] = [],
format: str = "text",
databoost: bool = False,
metadata_json_column: str = METADATA_COL_NAME,
staleness: Union[float, datetime.datetime] = 0.0,
client: Optional[Client] = None,
):
"""Initialize Spanner document loader.
Args:
instance_id: The Spanner instance to load data from.
database_id: The Spanner database to load data from.
query: A GoogleSQL or PostgreSQL query. Users must match dialect to their database.
client: The connection object to use. This can be used to customize project id and credentials.
content_columns: The list of column(s) or field(s) to use for a Document's page content.
Page content is the default field for embeddings generation.
metadata_columns: The list of column(s) or field(s) to use for metadata.
Expand All @@ -146,6 +160,7 @@ def __init__(
databoost: Use data boost on read. Note: needs extra IAM permissions and higher cost.
metadata_json_column: The name of the JSON column to use as the metadata's base dictionary.
staleness: The time bound for stale read. Takes either a datetime or float.
client: The connection object to use. This can be used to customize project id and credentials.
"""
self.instance_id = instance_id
self.database_id = database_id
Expand All @@ -158,7 +173,7 @@ def __init__(
if self.format not in formats:
raise Exception("Use one of 'text', 'JSON', 'YAML', 'CSV'.")
self.databoost = databoost
self.client = client
self.client = client_with_user_agent(client, USER_AGENT_LOADER)
self.staleness = staleness
instance = self.client.instance(instance_id)
if not instance.exists():
Expand Down Expand Up @@ -230,30 +245,30 @@ def __init__(
instance_id: str,
database_id: str,
table_name: str,
client: Client = Client(),
content_column: str = CONTENT_COL_NAME,
metadata_columns: List[str] = [],
metadata_json_column: str = METADATA_COL_NAME,
client: Optional[Client] = None,
):
"""Initialize Spanner document saver.
Args:
instance_id: The Spanner instance to load data to.
database_id: The Spanner database to load data to.
table_name: The table name to load data to.
client: The connection object to use. This can be used to customized project id and credentials.
content_column: The name of the content column. Defaulted to the first column.
metadata_columns: This is for user to opt-in a selection of columns to use. Defaulted to use
all columns.
metadata_json_column: The name of the special JSON column. Defaulted to use "langchain_metadata".
client: The connection object to use. This can be used to customized project id and credentials.
"""
self.instance_id = instance_id
self.database_id = database_id
self.table_name = table_name
self.content_column = content_column
self.metadata_columns = metadata_columns
self.metadata_json_column = metadata_json_column
self.client = client
self.client = client_with_user_agent(client, USER_AGENT_LOADER)
instance = self.client.instance(instance_id)
if not instance.exists():
raise Exception("Instance doesn't exist.")
Expand Down Expand Up @@ -349,7 +364,7 @@ def init_document_table(
metadata_json_column: The name of the special JSON column. Defaulted to use "langchain_metadata".
"""
primary_key = primary_key or content_column
client = Client()
client = client_with_user_agent(None, USER_AGENT_LOADER)
metadata_json_column = metadata_json_column if store_metadata else ""
instance = client.instance(instance_id)
if not instance.exists():
Expand Down
34 changes: 18 additions & 16 deletions tests/test_spanner_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def setup_database(self, client):
instance_id,
google_database,
table_name,
client,
client=client,
content_column="product_id",
metadata_columns=["product_name", "description", "price", "dummy_col"],
)
Expand Down Expand Up @@ -355,7 +355,7 @@ def test_loader_custom_format_error(self, client):
instance_id,
google_database,
query,
client,
client=client,
format="NOT_A_FORMAT",
)
docs = loader.load()
Expand All @@ -367,7 +367,7 @@ def test_loader_custom_content_key_error(self, client):
instance_id,
google_database,
query,
client,
client=client,
content_columns=["NOT_A_COLUMN"],
)
docs = loader.load()
Expand All @@ -379,7 +379,7 @@ def test_loader_custom_metadata_key_error(self, client):
instance_id,
google_database,
query,
client,
client=client,
metadata_columns=["NOT_A_COLUMN"],
)
docs = loader.load()
Expand All @@ -405,7 +405,7 @@ def test_loader_custom_json_metadata(self, client):
instance_id,
google_database,
table_name,
client,
client=client,
content_column="product_id",
metadata_columns=["product_name", "description", "price"],
metadata_json_column="my_metadata",
Expand Down Expand Up @@ -471,7 +471,7 @@ def setup_database(self, client):
instance_id,
pg_database,
table_name,
client,
client=client,
content_column="product_id",
metadata_columns=["product_name", "description", "price", "dummy_col"],
)
Expand Down Expand Up @@ -768,7 +768,7 @@ def test_loader_custom_format_error(self, client):
instance_id,
pg_database,
query,
client,
client=client,
format="NOT_A_FORMAT",
)

Expand All @@ -779,7 +779,7 @@ def test_loader_custom_content_key_error(self, client):
instance_id,
pg_database,
query,
client,
client=client,
content_columns=["NOT_A_COLUMN"],
)
docs = loader.load()
Expand All @@ -791,7 +791,7 @@ def test_loader_custom_metadata_key_error(self, client):
instance_id,
pg_database,
query,
client,
client=client,
metadata_columns=["NOT_A_COLUMN"],
)
docs = loader.load()
Expand All @@ -817,7 +817,7 @@ def test_loader_custom_json_metadata(self, client):
instance_id,
pg_database,
table_name,
client,
client=client,
content_column="product_id",
metadata_columns=["product_name", "description", "price"],
metadata_json_column="my_metadata",
Expand Down Expand Up @@ -881,7 +881,7 @@ def test_saver_google_sql(self, google_client):
instance_id, google_database, table_name
)
saver = SpannerDocumentSaver(
instance_id, google_database, table_name, google_client
instance_id, google_database, table_name, client=google_client
)
query = f"SELECT * FROM {table_name}"
loader = SpannerLoader(
Expand All @@ -901,7 +901,9 @@ def test_saver_google_sql(self, google_client):

def test_saver_pg(self, pg_client):
SpannerDocumentSaver.init_document_table(instance_id, pg_database, table_name)
saver = SpannerDocumentSaver(instance_id, pg_database, table_name, pg_client)
saver = SpannerDocumentSaver(
instance_id, pg_database, table_name, client=pg_client
)
query = f"SELECT * FROM {table_name}"
loader = SpannerLoader(
client=pg_client,
Expand Down Expand Up @@ -935,7 +937,7 @@ def test_saver_google_sql_with_custom_schema(self, google_client):
instance_id,
google_database,
table_name,
google_client,
client=google_client,
content_column="my_page_content",
)
query = f"SELECT * FROM {table_name}"
Expand Down Expand Up @@ -981,7 +983,7 @@ def test_saver_pg_with_custom_schema(self, pg_client):
instance_id,
pg_database,
table_name,
pg_client,
client=pg_client,
content_column="my_page_content",
)
query = f"SELECT * FROM {table_name}"
Expand Down Expand Up @@ -1015,7 +1017,7 @@ def test_delete(self, google_client):
instance_id, google_database, table_name
)
saver = SpannerDocumentSaver(
instance_id, google_database, table_name, google_client
instance_id, google_database, table_name, client=google_client
)
query = f"SELECT * FROM {table_name}"
loader = SpannerLoader(
Expand All @@ -1040,7 +1042,7 @@ def test_saver_with_bad_docs(self, google_client):
instance_id, google_database, table_name
)
saver = SpannerDocumentSaver(
instance_id, google_database, table_name, google_client
instance_id, google_database, table_name, client=google_client
)
with pytest.raises(Exception):
saver.add_documents([1, 2, 3])

0 comments on commit 8315d9d

Please sign in to comment.