Skip to content

Commit

Permalink
add user agent
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 committed Feb 21, 2024
1 parent 5cea61c commit 169a5ea
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions src/langchain_google_spanner/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@
import datetime
import json
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Union
from typing import Any, Dict, Iterator, List, Optional, Union

from google.cloud.spanner import Client, KeySet # type: ignore
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect # type: ignore
from google.cloud.spanner_v1.data_types import JsonObject # type: ignore
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document

from .version import __version__

# User-agent token attached to Spanner clients created or wrapped by this
# module. The "/" separates the product name from the package version so the
# token follows the conventional "product/version" form.
USER_AGENT_LOADER = "langchain-google-spanner-python:document_loader/" + __version__

OPERATION_TIMEOUT_SECONDS = 240
MUTATION_BATCH_SIZE = 1000

Expand All @@ -37,6 +41,17 @@ class Column:
nullable: bool = True


def client_with_user_agent(client: Optional[Client], user_agent: str) -> Client:
    """Return a Spanner client whose user agent includes ``user_agent``.

    Args:
        client: An existing client to tag, or ``None`` to create a default
            ``Client()``.
        user_agent: The user-agent token to set or append.

    Returns:
        The client that was passed in (or the newly created one), with its
        ``_client_info.user_agent`` updated in place.
    """
    # Only build a default client when none was supplied; an explicit identity
    # check avoids discarding a caller's client that happens to be falsy.
    if client is None:
        client = Client()
    # NOTE(review): this mutates the client's ``_client_info`` object in place;
    # if that object is shared between clients the change is visible to all of
    # them -- confirm against the Spanner client implementation.
    client_info = client._client_info
    current_agent = client_info.user_agent
    if not current_agent:
        client_info.user_agent = user_agent
    elif user_agent not in current_agent:
        # Append rather than overwrite so existing agent tokens are preserved;
        # the substring check keeps repeated calls from duplicating the token.
        client_info.user_agent = " ".join([current_agent, user_agent])
    return client


def _load_row_to_doc(
format: str,
content_columns: List[str],
Expand Down Expand Up @@ -123,21 +138,20 @@ def __init__(
instance_id: str,
database_id: str,
query: str,
client: Client = Client(),
content_columns: List[str] = [],
metadata_columns: List[str] = [],
format: str = "text",
databoost: bool = False,
metadata_json_column: str = METADATA_COL_NAME,
staleness: Union[float, datetime.datetime] = 0.0,
client: Optional[Client] = None,
):
"""Initialize Spanner document loader.
Args:
instance_id: The Spanner instance to load data from.
database_id: The Spanner database to load data from.
query: A GoogleSQL or PostgreSQL query. Users must match dialect to their database.
client: The connection object to use. This can be used to customize project id and credentials.
content_columns: The list of column(s) or field(s) to use for a Document's page content.
Page content is the default field for embeddings generation.
metadata_columns: The list of column(s) or field(s) to use for metadata.
Expand All @@ -146,6 +160,7 @@ def __init__(
databoost: Use data boost on read. Note: needs extra IAM permissions and higher cost.
metadata_json_column: The name of the JSON column to use as the metadata's base dictionary.
staleness: The time bound for stale read. Takes either a datetime or float.
client: The connection object to use. This can be used to customize project id and credentials.
"""
self.instance_id = instance_id
self.database_id = database_id
Expand All @@ -158,7 +173,7 @@ def __init__(
if self.format not in formats:
raise Exception("Use one of 'text', 'JSON', 'YAML', 'CSV'.")
self.databoost = databoost
self.client = client
self.client = client_with_user_agent(client, USER_AGENT_LOADER)
self.staleness = staleness
instance = self.client.instance(instance_id)
if not instance.exists():
Expand Down Expand Up @@ -230,30 +245,30 @@ def __init__(
instance_id: str,
database_id: str,
table_name: str,
client: Client = Client(),
content_column: str = CONTENT_COL_NAME,
metadata_columns: List[str] = [],
metadata_json_column: str = METADATA_COL_NAME,
client: Optional[Client] = None,
):
"""Initialize Spanner document saver.
Args:
instance_id: The Spanner instance to load data to.
database_id: The Spanner database to load data to.
table_name: The table name to load data to.
client: The connection object to use. This can be used to customize project id and credentials.
content_column: The name of the content column. Defaulted to the first column.
metadata_columns: This is for user to opt-in a selection of columns to use. Defaulted to use
all columns.
metadata_json_column: The name of the special JSON column. Defaulted to use "langchain_metadata".
client: The connection object to use. This can be used to customize project id and credentials.
"""
self.instance_id = instance_id
self.database_id = database_id
self.table_name = table_name
self.content_column = content_column
self.metadata_columns = metadata_columns
self.metadata_json_column = metadata_json_column
self.client = client
self.client = client_with_user_agent(client, USER_AGENT_LOADER)
instance = self.client.instance(instance_id)
if not instance.exists():
raise Exception("Instance doesn't exist.")
Expand Down Expand Up @@ -349,7 +364,7 @@ def init_document_table(
metadata_json_column: The name of the special JSON column. Defaulted to use "langchain_metadata".
"""
primary_key = primary_key or content_column
client = Client()
client = client_with_user_agent(None, USER_AGENT_LOADER)
metadata_json_column = metadata_json_column if store_metadata else ""
instance = client.instance(instance_id)
if not instance.exists():
Expand Down

0 comments on commit 169a5ea

Please sign in to comment.