From 289e289b5bb69cf65fb25a1297504429bce81235 Mon Sep 17 00:00:00 2001
From: Dave
Date: Fri, 6 Dec 2024 17:33:32 +0100
Subject: [PATCH] add support for column schema in certain query cases

---
 dlt/destinations/dataset/dataset.py       |   2 +-
 dlt/destinations/dataset/ibis_relation.py |  35 +++-
 tests/load/test_read_interfaces.py        | 191 +++++++++++-----------
 3 files changed, 131 insertions(+), 97 deletions(-)

diff --git a/dlt/destinations/dataset/dataset.py b/dlt/destinations/dataset/dataset.py
index 7ef407b115..e443045e49 100644
--- a/dlt/destinations/dataset/dataset.py
+++ b/dlt/destinations/dataset/dataset.py
@@ -121,7 +121,7 @@ def table(self, table_name: str) -> SupportsReadableRelation:
             from dlt.destinations.dataset.ibis_relation import ReadableIbisRelation
 
             unbound_table = create_unbound_ibis_table(self.sql_client, self.schema, table_name)
-            return ReadableIbisRelation(readable_dataset=self, ibis_object=unbound_table)  # type: ignore[abstract]
+            return ReadableIbisRelation(readable_dataset=self, ibis_object=unbound_table, columns_schema=self.schema.tables[table_name]["columns"])  # type: ignore[abstract]
         except MissingDependencyException:
             # if ibis is explicitly requested, reraise
             if self._dataset_type == "ibis":
diff --git a/dlt/destinations/dataset/ibis_relation.py b/dlt/destinations/dataset/ibis_relation.py
index bed5a9c883..632298ad56 100644
--- a/dlt/destinations/dataset/ibis_relation.py
+++ b/dlt/destinations/dataset/ibis_relation.py
@@ -47,17 +47,18 @@
 }
 
 
-# TODO: provide ibis expression typing for the readable relation
 class ReadableIbisRelation(BaseReadableDBAPIRelation):
     def __init__(
         self,
         *,
         readable_dataset: ReadableDBAPIDataset,
         ibis_object: Any = None,
+        columns_schema: TTableSchemaColumns = None,
     ) -> None:
         """Create a lazy evaluated relation to for the dataset of a destination"""
         super().__init__(readable_dataset=readable_dataset)
         self._ibis_object = ibis_object
+        self._columns_schema = columns_schema
 
     @property
     def query(self) -> Any:
@@ -89,7 +90,7 @@ def columns_schema(self, new_value: TTableSchemaColumns) -> None:
     def compute_columns_schema(self) -> TTableSchemaColumns:
         """provide schema columns for the cursor, may be filtered by selected columns"""
         # TODO: provide column lineage tracing with sqlglot lineage
-        return None
+        return self._columns_schema
 
     def _proxy_expression_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
         """Proxy method calls to the underlying ibis expression, allowing to wrap the resulting expression in a new relation"""
@@ -119,8 +120,18 @@ def _proxy_expression_method(self, method_name: str, *args: Any, **kwargs: Any)
         # Call it with provided args
         result = method(*args, **kwargs)
 
+        # calculate columns schema for the result; some operations we know will not change the schema
+        # and select will just reduce the number of columns
+        columns_schema = None
+        if method_name == "select":
+            columns_schema = self._get_filtered_columns_schema(args)
+        elif method_name in ["filter", "limit", "order_by", "head"]:
+            columns_schema = self._columns_schema
+
         # If result is an ibis expression, wrap it in a new relation else return raw result
-        return self.__class__(readable_dataset=self._dataset, ibis_object=result)
+        return self.__class__(
+            readable_dataset=self._dataset, ibis_object=result, columns_schema=columns_schema
+        )
 
     def __getattr__(self, name: str) -> Any:
         """Wrap all callable attributes of the expression"""
@@ -136,15 +147,31 @@ def __getattr__(self, name: str) -> Any:
             raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
 
         if not callable(attr):
+            # NOTE: we don't need to forward columns schema for non-callable attributes; these are usually columns
             return self.__class__(readable_dataset=self._dataset, ibis_object=attr)
 
         return partial(self._proxy_expression_method, name)
 
     def __getitem__(self, columns: Union[str, Sequence[str]]) -> "ReadableIbisRelation":
         # casefold column-names
+        columns = [columns] if isinstance(columns, str) else columns
         columns = [self.sql_client.capabilities.casefold_identifier(col) for col in columns]
         expr = self._ibis_object[columns]
-        return self.__class__(readable_dataset=self._dataset, ibis_object=expr)
+        return self.__class__(
+            readable_dataset=self._dataset,
+            ibis_object=expr,
+            columns_schema=self._get_filtered_columns_schema(columns),
+        )
+
+    def _get_filtered_columns_schema(self, columns: Sequence[str]) -> TTableSchemaColumns:
+        if not self._columns_schema:
+            return None
+        try:
+            return {col: self._columns_schema[col] for col in columns}
+        except KeyError:
+            # NOTE: select statements can contain new columns not present in the original schema
+            # here we just break the column schema inheritance chain
+            return None
 
     # forward ibis methods defined on interface
     def limit(self, limit: int, **kwargs: Any) -> "ReadableIbisRelation":
diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py
index 6488980cc2..d2f5f7951e 100644
--- a/tests/load/test_read_interfaces.py
+++ b/tests/load/test_read_interfaces.py
@@ -1,4 +1,4 @@
-from typing import Any, cast
+from typing import Any, cast, Tuple, List
 import re
 import pytest
 import dlt
@@ -493,142 +493,149 @@ def test_ibis_expression_relation(populated_pipeline: Pipeline) -> None:
         return
 
     # we check a bunch of expressions without executing them to see that they produce correct sql
-    def sql_from_expr(expr: Any) -> str:
+    # also we return the keys of the discovered schema columns
+    def sql_from_expr(expr: Any) -> Tuple[str, List[str]]:
         query = str(expr.query).replace(populated_pipeline.dataset_name, "dataset")
-        return re.sub(r"\s+", " ", query)
+        columns = list(expr.columns_schema.keys()) if expr.columns_schema else None
+        return re.sub(r"\s+", " ", query), columns
 
     # test all functions discussed here: https://ibis-project.org/tutorials/ibis-for-sql-users
+    ALL_COLUMNS = ["id", "decimal", "other_decimal", "_dlt_load_id", "_dlt_id"]
 
     # selecting two columns
-    assert (
-        sql_from_expr(items_table.select("id", "decimal"))
-        == 'SELECT "t0"."id", "t0"."decimal" FROM "dataset"."items" AS "t0"'
+    assert sql_from_expr(items_table.select("id", "decimal")) == (
+        'SELECT "t0"."id", "t0"."decimal" FROM "dataset"."items" AS "t0"',
+        ["id", "decimal"],
     )
+
+    # selecting all columns
+    assert sql_from_expr(items_table) == ('SELECT * FROM "dataset"."items"', ALL_COLUMNS)
+
+    # selecting two other columns via item getter
+    assert sql_from_expr(items_table["id", "decimal"]) == (
+        'SELECT "t0"."id", "t0"."decimal" FROM "dataset"."items" AS "t0"',
+        ["id", "decimal"],
+    )
 
     # adding a new columns
     new_col = (items_table.id * 2).name("new_col")
-    assert (
-        sql_from_expr(items_table.select("id", "decimal", new_col))
-        == 'SELECT "t0"."id", "t0"."decimal", "t0"."id" * 2 AS "new_col" FROM "dataset"."items" AS'
-        ' "t0"'
+    assert sql_from_expr(items_table.select("id", "decimal", new_col)) == (
+        (
+            'SELECT "t0"."id", "t0"."decimal", "t0"."id" * 2 AS "new_col" FROM'
+            ' "dataset"."items" AS "t0"'
+        ),
+        None,
     )
 
     # mutating table (add a new column computed from existing columns)
-    assert (
-        sql_from_expr(items_table.mutate(double_id=items_table.id * 2).select("id", "double_id"))
-        == 'SELECT "t0"."id", "t0"."id" * 2 AS "double_id" FROM "dataset"."items" AS "t0"'
+    assert sql_from_expr(
+        items_table.mutate(double_id=items_table.id * 2).select("id", "double_id")
+    ) == (
+        'SELECT "t0"."id", "t0"."id" * 2 AS "double_id" FROM "dataset"."items" AS "t0"',
+        None,
     )
 
     # mutating table add new static column
-    assert (
-        sql_from_expr(
-            items_table.mutate(new_col=ibis.literal("static_value")).select("id", "new_col")
-        )
-        == 'SELECT "t0"."id", \'static_value\' AS "new_col" FROM "dataset"."items" AS "t0"'
-    )
-
-    # check filtering
-    assert (
-        sql_from_expr(items_table.filter(items_table.id < 10))
-        == 'SELECT * FROM "dataset"."items" AS "t0" WHERE "t0"."id" < 10'
+    assert sql_from_expr(
+        items_table.mutate(new_col=ibis.literal("static_value")).select("id", "new_col")
+    ) == ('SELECT "t0"."id", \'static_value\' AS "new_col" FROM "dataset"."items" AS "t0"', None)
+
+    # check filtering (preserves all columns)
+    assert sql_from_expr(items_table.filter(items_table.id < 10)) == (
+        'SELECT * FROM "dataset"."items" AS "t0" WHERE "t0"."id" < 10',
+        ALL_COLUMNS,
     )
 
     # filtering and selecting a single column
-    assert (
-        sql_from_expr(items_table.filter(items_table.id < 10).select("id"))
-        == 'SELECT "t0"."id" FROM "dataset"."items" AS "t0" WHERE "t0"."id" < 10'
+    assert sql_from_expr(items_table.filter(items_table.id < 10).select("id")) == (
+        'SELECT "t0"."id" FROM "dataset"."items" AS "t0" WHERE "t0"."id" < 10',
+        ["id"],
     )
 
-    # check filter and
-    assert (
-        sql_from_expr(items_table.filter(items_table.id < 10).filter(items_table.id > 5))
-        == 'SELECT * FROM "dataset"."items" AS "t0" WHERE "t0"."id" < 10 AND "t0"."id" > 5'
+    # check filter "and" condition
+    assert sql_from_expr(items_table.filter(items_table.id < 10).filter(items_table.id > 5)) == (
+        'SELECT * FROM "dataset"."items" AS "t0" WHERE "t0"."id" < 10 AND "t0"."id" > 5',
+        ALL_COLUMNS,
     )
 
-    # check filter or
-    assert (
-        sql_from_expr(items_table.filter((items_table.id < 10) | (items_table.id > 5)))
-        == 'SELECT * FROM "dataset"."items" AS "t0" WHERE ( "t0"."id" < 10 ) OR ( "t0"."id" > 5 )'
+    # check filter "or" condition
+    assert sql_from_expr(items_table.filter((items_table.id < 10) | (items_table.id > 5))) == (
+        'SELECT * FROM "dataset"."items" AS "t0" WHERE ( "t0"."id" < 10 ) OR ( "t0"."id" > 5 )',
+        ALL_COLUMNS,
    )
 
     # check group by and aggregate
-    assert (
-        sql_from_expr(
-            items_table.group_by("id")
-            .having(items_table.count() >= 1000)
-            .aggregate(sum_id=items_table.id.sum())
-        )
-        == 'SELECT "t1"."id", "t1"."sum_id" FROM ( SELECT "t0"."id", SUM("t0"."id") AS "sum_id",'
-        ' COUNT(*) AS "CountStar(items)" FROM "dataset"."items" AS "t0" GROUP BY 1 ) AS "t1"'
-        ' WHERE "t1"."CountStar(items)" >= 1000'
+    assert sql_from_expr(
+        items_table.group_by("id")
+        .having(items_table.count() >= 1000)
+        .aggregate(sum_id=items_table.id.sum())
+    ) == (
+        (
+            'SELECT "t1"."id", "t1"."sum_id" FROM ( SELECT "t0"."id", SUM("t0"."id") AS "sum_id",'
+            ' COUNT(*) AS "CountStar(items)" FROM "dataset"."items" AS "t0" GROUP BY 1 ) AS "t1"'
+            ' WHERE "t1"."CountStar(items)" >= 1000'
+        ),
+        None,
     )
 
     # sorting and ordering
-    assert (
-        sql_from_expr(items_table.order_by("id", "decimal").limit(10))
-        == 'SELECT * FROM "dataset"."items" AS "t0" ORDER BY "t0"."id" ASC, "t0"."decimal" ASC'
-        " LIMIT 10"
+    assert sql_from_expr(items_table.order_by("id", "decimal").limit(10)) == (
+        (
"t0"."decimal" ASC' + " LIMIT 10" + ), + ALL_COLUMNS, ) # sort desc and asc - assert ( - sql_from_expr(items_table.order_by(ibis.desc("id"), ibis.asc("decimal")).limit(10)) - == 'SELECT * FROM "dataset"."items" AS "t0" ORDER BY "t0"."id" DESC, "t0"."decimal" ASC' - " LIMIT 10" + assert sql_from_expr(items_table.order_by(ibis.desc("id"), ibis.asc("decimal")).limit(10)) == ( + ( + 'SELECT * FROM "dataset"."items" AS "t0" ORDER BY "t0"."id" DESC, "t0"."decimal" ASC' + " LIMIT 10" + ), + ALL_COLUMNS, ) # offset and limit - assert ( - sql_from_expr(items_table.order_by("id").limit(10, offset=5)) - == 'SELECT * FROM "dataset"."items" AS "t0" ORDER BY "t0"."id" ASC LIMIT 10 OFFSET 5' + assert sql_from_expr(items_table.order_by("id").limit(10, offset=5)) == ( + 'SELECT * FROM "dataset"."items" AS "t0" ORDER BY "t0"."id" ASC LIMIT 10 OFFSET 5', + ALL_COLUMNS, ) # join - assert ( - sql_from_expr( - items_table.join(double_items_table, items_table.id == double_items_table.id)[ - ["id", "double_id"] - ] - ) - == 'SELECT "t2"."id", "t3"."double_id" FROM "dataset"."items" AS "t2" INNER JOIN' - ' "dataset"."double_items" AS "t3" ON "t2"."id" = "t3"."id"' + assert sql_from_expr( + items_table.join(double_items_table, items_table.id == double_items_table.id)[ + ["id", "double_id"] + ] + ) == ( + ( + 'SELECT "t2"."id", "t3"."double_id" FROM "dataset"."items" AS "t2" INNER JOIN' + ' "dataset"."double_items" AS "t3" ON "t2"."id" = "t3"."id"' + ), + None, ) # subqueries - assert ( - sql_from_expr(items_table.filter(items_table.decimal.isin(double_items_table.di_decimal))) - == 'SELECT * FROM "dataset"."items" AS "t0" WHERE "t0"."decimal" IN ( SELECT' - ' "t1"."di_decimal" FROM "dataset"."double_items" AS "t1" )' + assert sql_from_expr( + items_table.filter(items_table.decimal.isin(double_items_table.di_decimal)) + ) == ( + ( + 'SELECT * FROM "dataset"."items" AS "t0" WHERE "t0"."decimal" IN ( SELECT' + ' "t1"."di_decimal" FROM "dataset"."double_items" AS "t1" )' + ), + ALL_COLUMNS, ) # topk - assert ( - sql_from_expr(items_table.decimal.topk(10)) - == 'SELECT * FROM ( SELECT "t0"."decimal", COUNT(*) AS "CountStar(items)" FROM' - ' "dataset"."items" AS "t0" GROUP BY 1 ) AS "t1" ORDER BY "t1"."CountStar(items)" DESC' - " LIMIT 10" + assert sql_from_expr(items_table.decimal.topk(10)) == ( + ( + 'SELECT * FROM ( SELECT "t0"."decimal", COUNT(*) AS "CountStar(items)" FROM' + ' "dataset"."items" AS "t0" GROUP BY 1 ) AS "t1" ORDER BY "t1"."CountStar(items)" DESC' + " LIMIT 10" + ), + None, ) - # NOTE: here we test that dlt column type resolution still works - # re-enable this when lineage is implemented - # expected_decimal_precision = 10 - # expected_decimal_precision_2 = 12 - # expected_decimal_precision_di = 7 - # if populated_pipeline.destination.destination_type == "dlt.destinations.bigquery": - # # bigquery does not allow precision configuration.. 
- # expected_decimal_precision = 38 - # expected_decimal_precision_2 = 38 - # expected_decimal_precision_di = 38 - - # joined_table = items_table.join(double_items_table, items_table.id == double_items_table.id)[ - # ["decimal", "other_decimal", "di_decimal"] - # ].rename(decimal_renamed="di_decimal").limit(20) - # table = joined_table.arrow() - # print(joined_table.compute_columns_schema(force=True)) - # assert table.schema.field("decimal").type.precision == expected_decimal_precision - # assert table.schema.field("other_decimal").type.precision == expected_decimal_precision_2 - # assert table.schema.field("di_decimal").type.precision == expected_decimal_precision_di - @pytest.mark.no_load @pytest.mark.essential