Skip to content

Commit

Permalink
Allow passing in database for fetch_last_updated_timestamps (#20466)
Browse files Browse the repository at this point in the history
If the user does not have the database configured on their resource, as
is the case in our open platform, then they need to provide a fully
qualified db name as part of the table name.

How I tested this
Added a parameterization for providing db name via resource, and also
via the method.
  • Loading branch information
dpeng817 authored and jamiedemaria committed Mar 14, 2024
1 parent 145c117 commit 79b8383
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ def fetch_last_updated_timestamps(
snowflake_connection: Union[SqlDbConnection, snowflake.connector.SnowflakeConnection],
schema: str,
tables: Sequence[str],
database: Optional[str] = None,
) -> Mapping[str, datetime]:
"""Fetch the last updated times of a list of tables in Snowflake.
Expand All @@ -730,17 +731,23 @@ def fetch_last_updated_timestamps(
snowflake_connection (Union[SqlDbConnection, SnowflakeConnection]): A connection to Snowflake.
Accepts either a SnowflakeConnection or a sqlalchemy connection object,
which are the two types of connections emittable from the snowflake resource.
schema (str): The schema of the table.
schema (str): The schema of the tables to fetch the last updated time for.
tables (Sequence[str]): A list of table names to fetch the last updated time for.
database (Optional[str]): The database of the table. Only required if the connection
has not been set with a database.
Returns:
Mapping[str, datetime]: A dictionary of table names to their last updated time in UTC.
"""
check.invariant(len(tables) > 0, "Must provide at least one table name to query upon.")
tables_str = ", ".join([f"'{table_name}'" for table_name in tables])
fully_qualified_table_name = (
f"{database}.information_schema.tables" if database else "information_schema.tables"
)

query = f"""
SELECT table_name, CONVERT_TIMEZONE('UTC', last_altered) AS last_altered
FROM information_schema.tables
FROM {fully_qualified_table_name}
WHERE table_schema = '{schema}' AND table_name IN ({tables_str});
"""
result = snowflake_connection.cursor().execute(query)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,15 @@ def test_pydantic_snowflake_resource_duplicate_auth():

def test_fetch_last_updated_timestamps_empty():
with pytest.raises(CheckError):
fetch_last_updated_timestamps(snowflake_connection={}, schema="TESTSCHEMA", tables=[])
fetch_last_updated_timestamps(
snowflake_connection={}, schema="TESTSCHEMA", database="TESTDB", tables=[]
)


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB")
@pytest.mark.integration
def test_fetch_last_updated_timestamps():
@pytest.mark.parametrize("db_str", [None, "TESTDB"])
def test_fetch_last_updated_timestamps(db_str: str):
start_time = pendulum.now("UTC").timestamp()
table_name = "the_table"
with temporary_snowflake_table() as table_name:
Expand All @@ -301,7 +304,10 @@ def test_fetch_last_updated_timestamps():
def freshness_observe(snowflake: SnowflakeResource) -> ObserveResult:
with snowflake.get_connection() as conn:
freshness_for_table = fetch_last_updated_timestamps(
snowflake_connection=conn, schema="TESTSCHEMA", tables=[table_name]
snowflake_connection=conn,
database="TESTDB",
tables=[table_name],
schema="TESTSCHEMA",
)[table_name].timestamp()
return ObserveResult(
data_version=DataVersion("foo"),
Expand All @@ -317,8 +323,7 @@ def freshness_observe(snowflake: SnowflakeResource) -> ObserveResult:
account=os.getenv("SNOWFLAKE_ACCOUNT"),
user=os.environ["SNOWFLAKE_USER"],
password=os.getenv("SNOWFLAKE_PASSWORD"),
database="TESTDB",
schema="TESTSCHEMA",
database="TESTDB" if db_str is None else db_str,
)
},
)
Expand Down

0 comments on commit 79b8383

Please sign in to comment.