From 17cdb3a92a15f45bd2f9327ce32dc8c3c9e8b82c Mon Sep 17 00:00:00 2001 From: Daniel Gafni <49863538+danielgafni@users.noreply.github.com> Date: Fri, 30 Jun 2023 15:19:33 +0200 Subject: [PATCH] :star: add PolarsDeltaIOManager (#7) * :star: add `PolarsDeltaIOManager` * :wastebasket: remove legacy `polars_parquet_io_manager` --- .github/workflows/check.yml | 6 +- README.md | 12 +- dagster_polars/__init__.py | 5 +- dagster_polars/io_managers/base.py | 48 ++++++++ dagster_polars/io_managers/delta.py | 75 ++++++++++++ dagster_polars/io_managers/parquet.py | 45 ++++--- poetry.lock | 71 +++++++++++- pyproject.toml | 4 +- tests/conftest.py | 100 ++++++++++++---- tests/example.py | 11 +- tests/test_deltalake.py | 37 ++++++ tests/test_polars_delta.py | 160 +++++++++++++++++++++++++ tests/test_polars_parquet.py | 161 ++------------------------ tests/test_upath_io_managers.py | 141 ++++++++++++++++++++++ 14 files changed, 672 insertions(+), 204 deletions(-) create mode 100644 dagster_polars/io_managers/delta.py create mode 100644 tests/test_deltalake.py create mode 100644 tests/test_polars_delta.py create mode 100644 tests/test_upath_io_managers.py diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 37d4b62..8141452 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -36,7 +36,9 @@ jobs: virtualenvs-in-project: false installer-parallel: true - name: Install dependencies - run: poetry install --all-extras --sync && pip install polars~=${{ matrix.polars_version }} + run: poetry install --all-extras --sync && pip install --ignore-installed polars~=${{ matrix.polars_version }} + - name: Print polars info + run: python -c 'import polars; print(polars.show_versions())' - name: Run tests run: pytest -v . @@ -70,6 +72,6 @@ jobs: virtualenvs-in-project: false installer-parallel: true - name: Install dependencies - run: poetry install --all-extras --sync && pip install polars~=${{ matrix.polars_version }} + run: poetry install --all-extras --sync && pip install --ignore-installed polars~=${{ matrix.polars_version }} - name: Run pre-commit hooks run: pre-commit run --all-files diff --git a/README.md b/README.md index 7438127..54ee314 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,9 @@ - inherits all the features of the `UPathIOManager` - works with local and remote filesystems (like S3), supports loading multiple partitions (use `dict[str, pl.DataFrame]` type annotation), ... - Implemented serialization formats: - - `PolarsParquetIOManager` - for reading and writing files in Apache Parquet format. Supports reading partitioned Parquet datasets (for example, often produced by Spark). - - `BigQueryPolarsIOManager` - for reading and writing data from/to [BigQuery](https://cloud.google.com/bigquery). Supports writing partitioned tables (`"partition_expr"` input metadata key must be specified). + - `PolarsParquetIOManager` - for reading and writing files in Apache Parquet format. Supports reading partitioned Parquet datasets (for example, often produced by Spark). All read/write options can be set via metadata values. + - `PolarsDeltaIOManager` - for reading and writing Delta Lake. All read/write options can be set via metadata values. `"partition_by"` metadata value can be set to use native Delta Lake partitioning (it's passed to `delta_write_options` of `write_delta`). In this case, all the asset partitions will be stored in the same Delta Table directory. You are responsible for filtering correct partitions when reading the data in the downstream assets. 
Extra dependencies can be installed with `pip install 'dagster-polars[deltalake]'`. + - `BigQueryPolarsIOManager` - for reading and writing data from/to [BigQuery](https://cloud.google.com/bigquery). Supports writing partitioned tables (`"partition_expr"` input metadata key must be specified). Extra dependencies can be installed with `pip install 'dagster-polars[gcp]'`. ## Quickstart @@ -67,7 +68,6 @@ poetry run pre-commit install poetry run pytest ``` -## TODO - - [ ] Add `PolarsDeltaIOManager` - - [ ] Data validation like in [dagster-pandas](https://docs.dagster.io/integrations/pandas#validating-pandas-dataframes-with-dagster-types) - - [ ] Maybe use `DagsterTypeLoader` ? +## Ideas + - Data validation like in [dagster-pandas](https://docs.dagster.io/integrations/pandas#validating-pandas-dataframes-with-dagster-types) + - Maybe use `DagsterTypeLoader` ? diff --git a/dagster_polars/__init__.py b/dagster_polars/__init__.py index d5ec052..34fc661 100644 --- a/dagster_polars/__init__.py +++ b/dagster_polars/__init__.py @@ -1,5 +1,6 @@ from dagster_polars._version import __version__ from dagster_polars.io_managers.base import BasePolarsUPathIOManager -from dagster_polars.io_managers.parquet import PolarsParquetIOManager, polars_parquet_io_manager +from dagster_polars.io_managers.delta import PolarsDeltaIOManager +from dagster_polars.io_managers.parquet import PolarsParquetIOManager -__all__ = ["PolarsParquetIOManager", "BasePolarsUPathIOManager", "polars_parquet_io_manager", "__version__"] +__all__ = ["PolarsParquetIOManager", "PolarsDeltaIOManager", "BasePolarsUPathIOManager", "__version__"] diff --git a/dagster_polars/io_managers/base.py b/dagster_polars/io_managers/base.py index 71e3861..85c3f5e 100644 --- a/dagster_polars/io_managers/base.py +++ b/dagster_polars/io_managers/base.py @@ -11,6 +11,7 @@ InitResourceContext, InputContext, MetadataValue, + MultiPartitionKey, OutputContext, TableColumn, TableMetadataValue, @@ -178,3 +179,50 @@ def load_from_path(self, path: UPath, context: InputContext) -> Union[pl.DataFra def get_metadata(self, context: OutputContext, obj: pl.DataFrame) -> Dict[str, MetadataValue]: return get_polars_metadata(context, obj) + + @staticmethod + def get_storage_options(path: UPath) -> dict: + storage_options = {} + + try: + storage_options.update(path._kwargs.copy()) + except AttributeError: + pass + + return storage_options + + def get_path_for_partition(self, context: Union[InputContext, OutputContext], path: UPath, partition: str) -> UPath: + """ + Override this method if you want to use a different partitioning scheme + (for example, if the saving function handles partitioning instead). + The extension will be added later. 
+ :param context: + :param path: asset path before partitioning + :param partition: formatted partition key + :return: + """ + return path / partition + + def _get_paths_for_partitions(self, context: Union[InputContext, OutputContext]) -> Dict[str, "UPath"]: + """Returns a dict of partition_keys into I/O paths for a given context.""" + if not context.has_asset_partitions: + raise TypeError( + f"Detected {context.dagster_type.typing_type} input type " "but the asset is not partitioned" + ) + + def _formatted_multipartitioned_path(partition_key: MultiPartitionKey) -> str: + ordered_dimension_keys = [ + key[1] for key in sorted(partition_key.keys_by_dimension.items(), key=lambda x: x[0]) + ] + return "/".join(ordered_dimension_keys) + + formatted_partition_keys = [ + _formatted_multipartitioned_path(pk) if isinstance(pk, MultiPartitionKey) else pk + for pk in context.asset_partition_keys + ] + + asset_path = self._get_path_without_extension(context) + return { + partition: self._with_extension(self.get_path_for_partition(context, asset_path, partition)) + for partition in formatted_partition_keys + } diff --git a/dagster_polars/io_managers/delta.py b/dagster_polars/io_managers/delta.py new file mode 100644 index 0000000..3ee27c1 --- /dev/null +++ b/dagster_polars/io_managers/delta.py @@ -0,0 +1,75 @@ +from pprint import pformat +from typing import Union + +import polars as pl +from dagster import InputContext, OutputContext +from deltalake import DeltaTable +from upath import UPath + +from dagster_polars.io_managers.base import BasePolarsUPathIOManager + + +class PolarsDeltaIOManager(BasePolarsUPathIOManager): + extension: str = ".delta" + + assert BasePolarsUPathIOManager.__doc__ is not None + __doc__ = ( + BasePolarsUPathIOManager.__doc__ + + """\nWorks with Delta files. 
+ All read/write arguments can be passed via corresponding metadata values.""" + ) + + def get_path_for_partition(self, context: Union[InputContext, OutputContext], path: UPath, partition: str) -> UPath: + if isinstance(context, InputContext): + if ( + context.upstream_output is not None + and context.upstream_output.metadata is not None + and context.upstream_output.metadata.get("partition_by") is not None + ): + # upstream asset has "partition_by" metadata set, so partitioning for it is handled by DeltaLake itself + return path + + if isinstance(context, OutputContext): + if context.metadata is not None and context.metadata.get("partition_by") is not None: + # this asset has "partition_by" metadata set, so partitioning for it is handled by DeltaLake itself + return path + + return path / partition # partitioning is handled by the IOManager + + def dump_df_to_path(self, context: OutputContext, df: pl.DataFrame, path: UPath): + assert context.metadata is not None + + delta_write_options = context.metadata.get("delta_write_options") + + if context.has_asset_partitions: + delta_write_options = delta_write_options or {} + partition_by = context.metadata.get("partition_by") + + if partition_by is not None: + delta_write_options["partition_by"] = partition_by + + if delta_write_options is not None: + context.log.debug(f"Writing with delta_write_options: {pformat(delta_write_options)}") + + storage_options = self.get_storage_options(path) + + df.write_delta( + str(path), + mode=context.metadata.get("mode", "overwrite"), # type: ignore + overwrite_schema=context.metadata.get("overwrite_schema", False), + storage_options=storage_options, + delta_write_options=delta_write_options, + ) + table = DeltaTable(str(path), storage_options=storage_options) + context.add_output_metadata({"version": table.version()}) + + def scan_df_from_path(self, path: UPath, context: InputContext) -> pl.LazyFrame: + assert context.metadata is not None + + return pl.scan_delta( + str(path), + version=context.metadata.get("version"), + delta_table_options=context.metadata.get("delta_table_options"), + pyarrow_options=context.metadata.get("pyarrow_options"), + storage_options=self.get_storage_options(path), + ) diff --git a/dagster_polars/io_managers/parquet.py b/dagster_polars/io_managers/parquet.py index 16bce3e..5daf35e 100644 --- a/dagster_polars/io_managers/parquet.py +++ b/dagster_polars/io_managers/parquet.py @@ -3,7 +3,7 @@ import fsspec import polars as pl import pyarrow.dataset as ds -from dagster import InitResourceContext, InputContext, OutputContext, io_manager +from dagster import InputContext, OutputContext from upath import UPath from dagster_polars.io_managers.base import BasePolarsUPathIOManager @@ -12,13 +12,30 @@ class PolarsParquetIOManager(BasePolarsUPathIOManager): extension: str = ".parquet" - __doc__ = BasePolarsUPathIOManager.__doc__ + """\nWorks with Parquet files""" # type: ignore + assert BasePolarsUPathIOManager.__doc__ is not None + __doc__ = ( + BasePolarsUPathIOManager.__doc__ + + """\nWorks with Parquet files. 
+ All read/write arguments can be passed via corresponding metadata values.""" + ) def dump_df_to_path(self, context: OutputContext, df: pl.DataFrame, path: UPath): + assert context.metadata is not None + with path.open("wb") as file: - df.write_parquet(file) + df.write_parquet( + file, + compression=context.metadata.get("compression", "zstd"), + compression_level=context.metadata.get("compression_level"), + statistics=context.metadata.get("statistics", False), + row_group_size=context.metadata.get("row_group_size"), + use_pyarrow=context.metadata.get("use_pyarrow", False), + pyarrow_options=context.metadata.get("pyarrow_options"), + ) def scan_df_from_path(self, path: UPath, context: InputContext) -> pl.LazyFrame: + assert context.metadata is not None + fs: Union[fsspec.AbstractFileSystem, None] = None try: @@ -26,13 +43,15 @@ def scan_df_from_path(self, path: UPath, context: InputContext) -> pl.LazyFrame: except AttributeError: pass - return pl.scan_pyarrow_dataset(ds.dataset(str(path), filesystem=fs)) - - -# old non-pythonic IOManager, you are encouraged to use the `PolarsParquetIOManager` instead -@io_manager( - config_schema=PolarsParquetIOManager.to_config_schema(), - description=PolarsParquetIOManager.__doc__, -) -def polars_parquet_io_manager(context: InitResourceContext): - return PolarsParquetIOManager.from_resource_context(context) + return pl.scan_pyarrow_dataset( + ds.dataset( + str(path), + filesystem=fs, + format=context.metadata.get("format", "parquet"), + partitioning=context.metadata.get("partitioning"), + partition_base_dir=context.metadata.get("partition_base_dir"), + exclude_invalid_files=context.metadata.get("exclude_invalid_files", True), + ignore_prefixes=context.metadata.get("ignore_prefixes", [".", "_"]), + ), + allow_pyarrow_filter=context.metadata.get("allow_pyarrow_filter", True), + ) diff --git a/poetry.lock b/poetry.lock index 33831ed..bbf1b6f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -490,6 +490,21 @@ packaging = ">=17.0" pandas = ">=0.24.2" pyarrow = ">=3.0.0" +[[package]] +name = "decopatch" +version = "1.4.10" +description = "Create decorators easily in python." 
+category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "decopatch-1.4.10-py2.py3-none-any.whl", hash = "sha256:e151f7f93de2b1b3fd3f3272dcc7cefd1a69f68ec1c2d8e288ecd9deb36dc5f7"}, + {file = "decopatch-1.4.10.tar.gz", hash = "sha256:957f49c93f4150182c23f8fb51d13bb3213e0f17a79e09c8cca7057598b55720"}, +] + +[package.dependencies] +makefun = ">=1.5.0" + [[package]] name = "deepdiff" version = "6.3.0" @@ -509,6 +524,31 @@ ordered-set = ">=4.0.2,<4.2.0" cli = ["click (==8.1.3)", "pyyaml (==6.0)"] optimize = ["orjson"] +[[package]] +name = "deltalake" +version = "0.10.0" +description = "Native Delta Lake Python binding based on delta-rs with Pandas integration" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "deltalake-0.10.0-cp37-abi3-macosx_10_7_x86_64.whl", hash = "sha256:1aef72679268324e7cc556f03969b92235076b6edcecae7a2f85ac930b5fb9b9"}, + {file = "deltalake-0.10.0-cp37-abi3-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:91e801b78e298946950b07bf0573e4cf012d6259ac294bcac662b6d746f95596"}, + {file = "deltalake-0.10.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:437f8ee9f28a9cd8454e8700c3dd88f5ad33c2ec9a7192356b55b23fee50259c"}, + {file = "deltalake-0.10.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f532b3c3dc8437fc26279a296bc576bdf79b167e1d4864847ca7c5ec7b3915b6"}, + {file = "deltalake-0.10.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42212ee879735d476772a9c1d1ca3e43ef15893a301796baffe94db3578e98ca"}, + {file = "deltalake-0.10.0-cp37-abi3-win_amd64.whl", hash = "sha256:88c7759d8a97ec882dd0309eb097506c1682637ca4a0dcbe04d6482eedd298c9"}, + {file = "deltalake-0.10.0.tar.gz", hash = "sha256:b8793a1c7a1219c8935ed9926a6870c9341597396c0f50bf6ff98c323d06ee0f"}, +] + +[package.dependencies] +pyarrow = ">=7" + +[package.extras] +devel = ["black", "mypy", "packaging (>=20)", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-timeout", "ruff", "sphinx (<=4.5)", "sphinx-rtd-theme", "toml", "wheel"] +pandas = ["pandas (<2)"] +pyspark = ["delta-spark", "numpy (==1.22.2)", "pyspark"] + [[package]] name = "distlib" version = "0.3.6" @@ -1469,6 +1509,18 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "makefun" +version = "1.15.1" +description = "Small library to dynamically create python functions." +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "makefun-1.15.1-py2.py3-none-any.whl", hash = "sha256:a63cfc7b47a539c76d97bd4fdb833c7d0461e759fd1225f580cb4be6200294d4"}, + {file = "makefun-1.15.1.tar.gz", hash = "sha256:40b0f118b6ded0d8d78c78f1eb679b8b6b2462e3c1b3e05fb1b2da8cd46b48a5"}, +] + [[package]] name = "mako" version = "1.2.4" @@ -2215,6 +2267,22 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] +[[package]] +name = "pytest-cases" +version = "3.6.14" +description = "Separate test code from test cases in pytest." 
+category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "pytest-cases-3.6.14.tar.gz", hash = "sha256:7455e6ca57a544c1bfdd8b56ace08c1c1ce4c6572a8aab8f1bd351dc25a10b6b"}, + {file = "pytest_cases-3.6.14-py2.py3-none-any.whl", hash = "sha256:a087f3d019efd8942d0f0dc3fb526bedf9f83d742c40289e9623f6788aff7257"}, +] + +[package.dependencies] +decopatch = "*" +makefun = ">=1.9.5" + [[package]] name = "python-dateutil" version = "2.8.2" @@ -3101,9 +3169,10 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] +deltalake = ["deltalake"] gcp = ["dagster-gcp"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "997038d8fdd8dd2099d36ba6bd516ab174fcfda4eb9a23b2c28a75e0a9508874" +content-hash = "7ea8774076f190ff20ece104276bb592f1fbb8df88a608f665f1f2678d5fe1be" diff --git a/pyproject.toml b/pyproject.toml index e61c8bd..bddfaa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,12 +30,13 @@ python = "^3.8" dagster = "^1.3.5" polars = ">=0.17.0" pyarrow = ">=8.0.0" +deltalake = "^0.10.0" dagster-gcp = "^0.19.5" [tool.poetry.extras] gcp = ["dagster-gcp"] - +deltalake = ["deltalake"] [tool.poetry.group.dev.dependencies] @@ -50,6 +51,7 @@ tox-gh = "^1.0.0" pre-commit = "^3.3.2" dagit = "^1.3.9" black = "^23.3.0" +pytest-cases = "^3.6.14" [build-system] requires = ["poetry-core"] diff --git a/tests/conftest.py b/tests/conftest.py index 5872d71..64543e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,39 +1,95 @@ +import logging from datetime import date, datetime, timedelta +from typing import Tuple, Type import polars as pl import pytest +import pytest_cases from _pytest.tmpdir import TempPathFactory -from dagster import DagsterInstance, IOManagerDefinition +from dagster import DagsterInstance -from dagster_polars import PolarsParquetIOManager, polars_parquet_io_manager +from dagster_polars import BasePolarsUPathIOManager, PolarsDeltaIOManager, PolarsParquetIOManager +logging.getLogger("alembic.runtime.migration").setLevel(logging.WARNING) -@pytest.fixture(scope="session") -def dagster_instance(tmp_path_factory: TempPathFactory): + +@pytest.fixture +def dagster_instance(tmp_path_factory: TempPathFactory) -> DagsterInstance: return DagsterInstance.ephemeral(tempdir=str(tmp_path_factory.mktemp("dagster_home"))) -@pytest.fixture(scope="session") -def tmp_polars_parquet_io_manager(dagster_instance: DagsterInstance) -> PolarsParquetIOManager: +@pytest.fixture +def polars_parquet_io_manager(dagster_instance: DagsterInstance) -> PolarsParquetIOManager: return PolarsParquetIOManager(base_dir=dagster_instance.storage_directory()) +@pytest.fixture +def polars_delta_io_manager(dagster_instance: DagsterInstance) -> PolarsDeltaIOManager: + return PolarsDeltaIOManager(base_dir=dagster_instance.storage_directory()) + + @pytest.fixture(scope="session") -def tmp_polars_parquet_io_manager_legacy(dagster_instance: DagsterInstance) -> IOManagerDefinition: - return polars_parquet_io_manager.configured({"base_dir": dagster_instance.storage_directory()}) +def session_scoped_dagster_instance(tmp_path_factory: TempPathFactory) -> DagsterInstance: + return DagsterInstance.ephemeral(tempdir=str(tmp_path_factory.mktemp("dagster_home_session"))) -@pytest.fixture -def df(): - return pl.DataFrame( - { 
- "1": [0, 1, None], - "2": [0.0, 1.0, None], - "3": ["a", "b", None], - "4": [[0, 1], [2, 3], None], - "6": [{"a": 0}, {"a": 1}, None], - "7": [datetime(2022, 1, 1), datetime(2022, 1, 2), None], - "8": [date(2022, 1, 1), date(2022, 1, 2), None], - "9": [timedelta(hours=1), timedelta(hours=2), None], - } - ) +@pytest.fixture(scope="session") +def session_polars_parquet_io_manager(session_scoped_dagster_instance: DagsterInstance) -> PolarsParquetIOManager: + return PolarsParquetIOManager( + base_dir=session_scoped_dagster_instance.storage_directory() + ) # to use with hypothesis + + +@pytest.fixture(scope="session") +def session_polars_delta_io_manager(session_scoped_dagster_instance: DagsterInstance) -> PolarsDeltaIOManager: + return PolarsDeltaIOManager(base_dir=session_scoped_dagster_instance.storage_directory()) # to use with hypothesis + + +_df_for_parquet = pl.DataFrame( + { + "1": [0, 1, None], + "2": [0.0, 1.0, None], + "3": ["a", "b", None], + "4": [[0, 1], [2, 3], None], + "6": [{"a": 0}, {"a": 1}, None], + "7": [datetime(2022, 1, 1), datetime(2022, 1, 2), None], + "8": [date(2022, 1, 1), date(2022, 1, 2), None], + "9": [timedelta(hours=1), timedelta(hours=2), None], + } +) + + +@pytest_cases.fixture(scope="session") +def df_for_parquet() -> pl.DataFrame: + return _df_for_parquet + + +@pytest_cases.fixture(scope="session") +def df_for_delta() -> pl.DataFrame: + return _df_for_delta + + +# delta doesn't support Duration +# TODO: add timedeltas when supported +_df_for_delta = pl.DataFrame( + { + "1": [0, 1, None], + "2": [0.0, 1.0, None], + "3": ["a", "b", None], + "4": [[0, 1], [2, 3], None], + "6": [{"a": 0}, {"a": 1}, None], + "7": [datetime(2022, 1, 1), datetime(2022, 1, 2), None], + "8": [date(2022, 1, 1), date(2022, 1, 2), None], + } +) + + +@pytest_cases.fixture +@pytest_cases.parametrize( + "class_and_df", [(PolarsParquetIOManager, _df_for_parquet), (PolarsDeltaIOManager, _df_for_delta)] +) +def io_manager_and_df( # to use without hypothesis + class_and_df: Tuple[Type[BasePolarsUPathIOManager], pl.DataFrame], dagster_instance: DagsterInstance +) -> Tuple[BasePolarsUPathIOManager, pl.DataFrame]: + klass, df = class_and_df + return klass(base_dir=dagster_instance.storage_directory()), df diff --git a/tests/example.py b/tests/example.py index d846086..90fa324 100644 --- a/tests/example.py +++ b/tests/example.py @@ -3,12 +3,17 @@ import polars as pl from dagster import Definitions, asset -from dagster_polars import PolarsParquetIOManager +from dagster_polars import PolarsDeltaIOManager, PolarsParquetIOManager @asset(io_manager_def=PolarsParquetIOManager(base_dir="/tmp/dagster")) -def my_asset() -> pl.DataFrame: +def my_parquet_asset() -> pl.DataFrame: return pl.DataFrame({"a": [0, 1, None], "b": ["a", "b", "c"]}) -definitions = Definitions(assets=[my_asset]) +@asset(io_manager_def=PolarsDeltaIOManager(base_dir="/tmp/dagster")) +def my_delta_asset() -> pl.DataFrame: + return pl.DataFrame({"a": [0, 1, None], "b": ["a", "b", "c"]}) + + +definitions = Definitions(assets=[my_parquet_asset, my_delta_asset]) diff --git a/tests/test_deltalake.py b/tests/test_deltalake.py new file mode 100644 index 0000000..dc96da0 --- /dev/null +++ b/tests/test_deltalake.py @@ -0,0 +1,37 @@ +import shutil + +import polars as pl +import polars.testing as pl_testing +from _pytest.tmpdir import TempPathFactory +from hypothesis import given, settings +from polars.testing.parametric import dataframes + +# TODO: remove pl.Time once it's supported +# TODO: remove pl.Duration pl.Duration once it's supported +# 
https://github.com/pola-rs/polars/issues/9631 +# TODO: remove UInt types once they are fixed: +# https://github.com/pola-rs/polars/issues/9627 + + +@given( + df=dataframes( + excluded_dtypes=[ + pl.Categorical, + pl.Duration, + pl.Time, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Datetime("ns", None), + ], + min_size=5, + allow_infinities=False, + ) +) +@settings(max_examples=500, deadline=None) +def test_polars_delta_io(df: pl.DataFrame, tmp_path_factory: TempPathFactory): + tmp_path = tmp_path_factory.mktemp("data") + df.write_delta(str(tmp_path)) + pl_testing.assert_frame_equal(df, pl.read_delta(str(tmp_path))) + shutil.rmtree(str(tmp_path)) # cleanup manually because of hypothesis diff --git a/tests/test_polars_delta.py b/tests/test_polars_delta.py new file mode 100644 index 0000000..1cee737 --- /dev/null +++ b/tests/test_polars_delta.py @@ -0,0 +1,160 @@ +import shutil +from typing import Dict + +import polars as pl +import polars.testing as pl_testing +from dagster import OpExecutionContext, StaticPartitionsDefinition, asset, materialize +from deltalake import DeltaTable +from hypothesis import given, settings +from polars.testing.parametric import dataframes + +from dagster_polars import PolarsDeltaIOManager + +# TODO: remove pl.Time once it's supported +# TODO: remove pl.Duration pl.Duration once it's supported +# https://github.com/pola-rs/polars/issues/9631 +# TODO: remove UInt types once they are fixed: +# https://github.com/pola-rs/polars/issues/9627 + + +@given( + df=dataframes( + excluded_dtypes=[ + pl.Categorical, + pl.Duration, + pl.Time, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Datetime("ns", None), + ], + min_size=5, + allow_infinities=False, + ) +) +@settings(max_examples=500, deadline=None) +def test_polars_delta_io_manager(session_polars_delta_io_manager: PolarsDeltaIOManager, df: pl.DataFrame): + @asset(io_manager_def=session_polars_delta_io_manager, metadata={"overwrite_schema": True}) + def upstream() -> pl.DataFrame: + return df + + @asset(io_manager_def=session_polars_delta_io_manager, metadata={"overwrite_schema": True}) + def downstream(upstream: pl.LazyFrame) -> pl.DataFrame: + return upstream.collect(streaming=True) + + result = materialize( + [upstream, downstream], + ) + + handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("upstream"))) + + saved_path = handled_output_events[0].event_specific_data.metadata["path"].value # type: ignore[index,union-attr] + assert isinstance(saved_path, str) + pl_testing.assert_frame_equal(df, pl.read_delta(saved_path)) + shutil.rmtree(saved_path) # cleanup manually because of hypothesis + + +def test_polars_delta_io_manager_append(polars_delta_io_manager: PolarsDeltaIOManager): + df = pl.DataFrame( + { + "a": [1, 2, 3], + } + ) + + @asset(io_manager_def=polars_delta_io_manager, metadata={"mode": "append"}) + def append_asset() -> pl.DataFrame: + return df + + result = materialize( + [append_asset], + ) + + handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("append_asset"))) + saved_path = handled_output_events[0].event_specific_data.metadata["path"].value # type: ignore[index,union-attr] + assert isinstance(saved_path, str) + + materialize( + [append_asset], + ) + + pl_testing.assert_frame_equal(pl.concat([df, df]), pl.read_delta(saved_path)) + + +def test_polars_delta_io_manager_overwrite_schema(polars_delta_io_manager: PolarsDeltaIOManager): + @asset(io_manager_def=polars_delta_io_manager) + def 
overwrite_schema_asset() -> pl.DataFrame: # type: ignore + return pl.DataFrame( + { + "a": [1, 2, 3], + } + ) + + result = materialize( + [overwrite_schema_asset], + ) + + handled_output_events = list( + filter(lambda evt: evt.is_handled_output, result.events_for_node("overwrite_schema_asset")) + ) + saved_path = handled_output_events[0].event_specific_data.metadata["path"].value # type: ignore[index,union-attr] + assert isinstance(saved_path, str) + + @asset(io_manager_def=polars_delta_io_manager, metadata={"overwrite_schema": True, "mode": "overwrite"}) + def overwrite_schema_asset() -> pl.DataFrame: + return pl.DataFrame( + { + "b": ["1", "2", "3"], + } + ) + + materialize( + [overwrite_schema_asset], + ) + + pl_testing.assert_frame_equal( + pl.DataFrame( + { + "b": ["1", "2", "3"], + } + ), + pl.read_delta(saved_path), + ) + + +def test_polars_delta_native_partitioning(polars_delta_io_manager: PolarsDeltaIOManager, df_for_delta: pl.DataFrame): + manager = polars_delta_io_manager + df = df_for_delta + + partitions_def = StaticPartitionsDefinition(["a", "b"]) + + @asset(io_manager_def=manager, partitions_def=partitions_def, metadata={"partition_by": "partition"}) + def upstream_partitioned(context: OpExecutionContext) -> pl.DataFrame: + return df.with_columns(pl.lit(context.partition_key).alias("partition")) + + @asset(io_manager_def=manager) + def downstream_load_multiple_partitions(upstream_partitioned: Dict[str, pl.LazyFrame]) -> None: + for _df in upstream_partitioned.values(): + assert isinstance(_df, pl.LazyFrame), type(_df) + assert set(upstream_partitioned.keys()) == {"a", "b"}, upstream_partitioned.keys() + + for partition_key in ["a", "b"]: + result = materialize( + [upstream_partitioned], + partition_key=partition_key, + ) + + handled_output_events = list( + filter(lambda evt: evt.is_handled_output, result.events_for_node("upstream_partitioned")) + ) + saved_path = handled_output_events[0].event_specific_data.metadata["path"].value # type: ignore + assert isinstance(saved_path, str) + assert saved_path.endswith("upstream_partitioned.delta"), saved_path # DeltaLake should handle partitioning! 
+ assert DeltaTable(saved_path).metadata().partition_columns == ["partition"] + + materialize( + [ + upstream_partitioned.to_source_asset(), + downstream_load_multiple_partitions, + ], + ) diff --git a/tests/test_polars_parquet.py b/tests/test_polars_parquet.py index 0b10697..e42773f 100644 --- a/tests/test_polars_parquet.py +++ b/tests/test_polars_parquet.py @@ -1,179 +1,31 @@ -from typing import Dict +import os import polars as pl import polars.testing as pl_testing -from dagster import IOManagerDefinition, OpExecutionContext, StaticPartitionsDefinition, asset, materialize -from deepdiff import DeepDiff +from dagster import asset, materialize from hypothesis import given, settings from polars.testing.parametric import dataframes from dagster_polars import PolarsParquetIOManager -def test_polars_parquet_io_manager_stats_metadata( - tmp_polars_parquet_io_manager: PolarsParquetIOManager, -): - df = pl.DataFrame({"a": [0, 1, None], "b": ["a", "b", "c"]}) - - @asset(io_manager_key="polars_parquet_io_manager") - def upstream() -> pl.DataFrame: - return df - - result = materialize( - [upstream], - resources={"polars_parquet_io_manager": tmp_polars_parquet_io_manager}, - ) - - handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("upstream"))) - - stats = handled_output_events[0].event_specific_data.metadata["stats"].value # type: ignore # noqa - - assert ( - DeepDiff( - stats, - { - "a": { - "count": 3.0, - "null_count": 1.0, - "mean": 0.5, - "std": 0.7071067811865476, - "min": 0.0, - "max": 1.0, - "median": 0.5, - "25%": 0.0, - "75%": 1.0, - }, - "b": { - "count": "3", - "null_count": "0", - "mean": "null", - "std": "null", - "min": "a", - "max": "c", - "median": "null", - "25%": "null", - "75%": "null", - }, - }, - ) - == {} - ) - - # allowed_dtypes=[pl.List(inner) for inner in # list(pl.TEMPORAL_DTYPES | pl.FLOAT_DTYPES | pl.INTEGER_DTYPES) + [pl.Boolean, pl.Utf8]] @given(df=dataframes(excluded_dtypes=[pl.Categorical], min_size=5)) @settings(max_examples=100, deadline=None) -def test_polars_parquet_io_manager(tmp_polars_parquet_io_manager: PolarsParquetIOManager, df: pl.DataFrame): - @asset(io_manager_key="polars_parquet_io_manager") - def upstream() -> pl.DataFrame: - return df - - @asset(io_manager_key="polars_parquet_io_manager") - def downstream(upstream: pl.LazyFrame) -> pl.DataFrame: - return upstream.collect(streaming=True) - - result = materialize( - [upstream, downstream], - resources={"polars_parquet_io_manager": tmp_polars_parquet_io_manager}, - ) - - handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("upstream"))) - - saved_path = handled_output_events[0].event_specific_data.metadata["path"].value # type: ignore[index,union-attr] - assert isinstance(saved_path, str) - pl_testing.assert_frame_equal(df, pl.read_parquet(saved_path)) - - -def test_polars_parquet_io_manager_nested_dtypes( - tmp_polars_parquet_io_manager: PolarsParquetIOManager, df: pl.DataFrame +def test_polars_parquet_io_manager_read_write( + session_polars_parquet_io_manager: PolarsParquetIOManager, df: pl.DataFrame ): - @asset(io_manager_key="polars_parquet_io_manager") + @asset(io_manager_def=session_polars_parquet_io_manager) def upstream() -> pl.DataFrame: return df - @asset(io_manager_key="polars_parquet_io_manager") + @asset(io_manager_def=session_polars_parquet_io_manager) def downstream(upstream: pl.LazyFrame) -> pl.DataFrame: return upstream.collect(streaming=True) result = materialize( [upstream, downstream], - 
resources={"polars_parquet_io_manager": tmp_polars_parquet_io_manager}, - ) - - handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("upstream"))) - - saved_path = handled_output_events[0].event_specific_data.metadata["path"].value # type: ignore[index,union-attr] - assert isinstance(saved_path, str) - pl_testing.assert_frame_equal(df, pl.read_parquet(saved_path)) - - -def test_polars_parquet_io_manager_type_annotations( - tmp_polars_parquet_io_manager_legacy: IOManagerDefinition, df: pl.DataFrame -): - @asset(io_manager_key="polars_parquet_io_manager") - def upstream() -> pl.DataFrame: - return df - - @asset - def downstream_default_eager(upstream) -> None: - assert isinstance(upstream, pl.DataFrame), type(upstream) - - @asset - def downstream_eager(upstream: pl.DataFrame) -> None: - assert isinstance(upstream, pl.DataFrame), type(upstream) - - @asset - def downstream_lazy(upstream: pl.LazyFrame) -> None: - assert isinstance(upstream, pl.LazyFrame), type(upstream) - - partitions_def = StaticPartitionsDefinition(["a", "b"]) - - @asset(io_manager_key="polars_parquet_io_manager", partitions_def=partitions_def) - def upstream_partitioned(context: OpExecutionContext) -> pl.DataFrame: - return df.with_columns(pl.lit(context.partition_key).alias("partition")) - - @asset - def downstream_multi_partitioned_eager(upstream_partitioned: Dict[str, pl.DataFrame]) -> None: - for _df in upstream_partitioned.values(): - assert isinstance(_df, pl.DataFrame), type(_df) - assert set(upstream_partitioned.keys()) == {"a", "b"}, upstream_partitioned.keys() - - @asset - def downstream_multi_partitioned_lazy(upstream_partitioned: Dict[str, pl.LazyFrame]) -> None: - for _df in upstream_partitioned.values(): - assert isinstance(_df, pl.LazyFrame), type(_df) - assert set(upstream_partitioned.keys()) == {"a", "b"}, upstream_partitioned.keys() - - for partition_key in ["a", "b"]: - materialize( - [upstream_partitioned], - resources={"polars_parquet_io_manager": tmp_polars_parquet_io_manager_legacy}, - partition_key=partition_key, - ) - - materialize( - [ - upstream_partitioned.to_source_asset(), - upstream, - downstream_default_eager, - downstream_eager, - downstream_lazy, - downstream_multi_partitioned_eager, - downstream_multi_partitioned_lazy, - ], - resources={"polars_parquet_io_manager": tmp_polars_parquet_io_manager_legacy}, - ) - - -def test_polars_parquet_io_manager_legacy(tmp_polars_parquet_io_manager_legacy: IOManagerDefinition, df: pl.DataFrame): - @asset(io_manager_key="polars_parquet_io_manager") - def upstream() -> pl.DataFrame: - return df - - result = materialize( - [upstream], - resources={"polars_parquet_io_manager": tmp_polars_parquet_io_manager_legacy}, ) handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("upstream"))) @@ -181,3 +33,4 @@ def upstream() -> pl.DataFrame: saved_path = handled_output_events[0].event_specific_data.metadata["path"].value # type: ignore[index,union-attr] assert isinstance(saved_path, str) pl_testing.assert_frame_equal(df, pl.read_parquet(saved_path)) + os.remove(saved_path) # cleanup manually because of hypothesis diff --git a/tests/test_upath_io_managers.py b/tests/test_upath_io_managers.py new file mode 100644 index 0000000..bcc65cc --- /dev/null +++ b/tests/test_upath_io_managers.py @@ -0,0 +1,141 @@ +from typing import Dict, Tuple + +import polars as pl +import polars.testing as pl_testing +from dagster import OpExecutionContext, StaticPartitionsDefinition, asset, materialize +from 
deepdiff import DeepDiff + +from dagster_polars import BasePolarsUPathIOManager, PolarsDeltaIOManager, PolarsParquetIOManager + + +def test_polars_upath_io_manager_stats_metadata(io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame]): + manager, _ = io_manager_and_df + + df = pl.DataFrame({"a": [0, 1, None], "b": ["a", "b", "c"]}) + + @asset(io_manager_def=manager) + def upstream() -> pl.DataFrame: + return df + + result = materialize( + [upstream], + ) + + handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("upstream"))) + + stats = handled_output_events[0].event_specific_data.metadata["stats"].value # type: ignore # noqa + + assert ( + DeepDiff( + stats, + { + "a": { + "count": 3.0, + "null_count": 1.0, + "mean": 0.5, + "std": 0.7071067811865476, + "min": 0.0, + "max": 1.0, + "median": 0.5, + "25%": 0.0, + "75%": 1.0, + }, + "b": { + "count": "3", + "null_count": "0", + "mean": "null", + "std": "null", + "min": "a", + "max": "c", + "median": "null", + "25%": "null", + "75%": "null", + }, + }, + ) + == {} + ) + + +def test_polars_upath_io_manager_type_annotations(io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame]): + manager, df = io_manager_and_df + + @asset(io_manager_def=manager) + def upstream() -> pl.DataFrame: + return df + + @asset(io_manager_def=manager) + def downstream_default_eager(upstream) -> None: + assert isinstance(upstream, pl.DataFrame), type(upstream) + + @asset(io_manager_def=manager) + def downstream_eager(upstream: pl.DataFrame) -> None: + assert isinstance(upstream, pl.DataFrame), type(upstream) + + @asset(io_manager_def=manager) + def downstream_lazy(upstream: pl.LazyFrame) -> None: + assert isinstance(upstream, pl.LazyFrame), type(upstream) + + partitions_def = StaticPartitionsDefinition(["a", "b"]) + + @asset(io_manager_def=manager, partitions_def=partitions_def) + def upstream_partitioned(context: OpExecutionContext) -> pl.DataFrame: + return df.with_columns(pl.lit(context.partition_key).alias("partition")) + + @asset(io_manager_def=manager) + def downstream_multi_partitioned_eager(upstream_partitioned: Dict[str, pl.DataFrame]) -> None: + for _df in upstream_partitioned.values(): + assert isinstance(_df, pl.DataFrame), type(_df) + assert set(upstream_partitioned.keys()) == {"a", "b"}, upstream_partitioned.keys() + + @asset(io_manager_def=manager) + def downstream_multi_partitioned_lazy(upstream_partitioned: Dict[str, pl.LazyFrame]) -> None: + for _df in upstream_partitioned.values(): + assert isinstance(_df, pl.LazyFrame), type(_df) + assert set(upstream_partitioned.keys()) == {"a", "b"}, upstream_partitioned.keys() + + for partition_key in ["a", "b"]: + materialize( + [upstream_partitioned], + partition_key=partition_key, + ) + + materialize( + [ + upstream_partitioned.to_source_asset(), + upstream, + downstream_default_eager, + downstream_eager, + downstream_lazy, + downstream_multi_partitioned_eager, + downstream_multi_partitioned_lazy, + ], + ) + + +def test_polars_upath_io_manager_nested_dtypes(io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame]): + manager, df = io_manager_and_df + + @asset(io_manager_def=manager) + def upstream() -> pl.DataFrame: + return df + + @asset(io_manager_def=manager) + def downstream(upstream: pl.LazyFrame) -> pl.DataFrame: + return upstream.collect(streaming=True) + + result = materialize( + [upstream, downstream], + ) + + handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("upstream"))) + + saved_path = 
handled_output_events[0].event_specific_data.metadata["path"].value # type: ignore[index,union-attr] + assert isinstance(saved_path, str) + + if isinstance(manager, PolarsParquetIOManager): + pl_testing.assert_frame_equal(df, pl.read_parquet(saved_path)) + elif isinstance(manager, PolarsDeltaIOManager): + pl_testing.assert_frame_equal(df, pl.read_delta(saved_path)) + else: + raise ValueError(f"Test not implemented for {type(manager)}")
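
Usage sketch (illustrative, not part of the patch itself): the snippet below wires the new `PolarsDeltaIOManager` into a Dagster code location, following the patterns in tests/example.py and tests/test_polars_delta.py above. The IO manager classes, the `base_dir` argument, the `"partition_by"` / `"mode"` / Parquet write metadata keys, and the eager/lazy loading behaviour come from this patch; the asset names, column names, and the `/tmp/dagster` path are assumptions made up for illustration.

```python
from typing import Dict

import polars as pl
from dagster import Definitions, OpExecutionContext, StaticPartitionsDefinition, asset

from dagster_polars import PolarsDeltaIOManager, PolarsParquetIOManager

# base_dir can be a local directory or a remote (e.g. S3) location supported by universal-pathlib
delta_io_manager = PolarsDeltaIOManager(base_dir="/tmp/dagster")

partitions_def = StaticPartitionsDefinition(["a", "b"])


# "partition_by" is forwarded to delta_write_options of write_delta, so Delta Lake itself
# partitions the table: all asset partitions are written into one daily_events.delta directory
@asset(
    io_manager_def=delta_io_manager,
    partitions_def=partitions_def,
    metadata={"partition_by": "partition", "mode": "append"},
)
def daily_events(context: OpExecutionContext) -> pl.DataFrame:
    return pl.DataFrame({"value": [1, 2, 3]}).with_columns(pl.lit(context.partition_key).alias("partition"))


# annotating inputs as (dicts of) pl.LazyFrame makes the IO manager use pl.scan_delta;
# because Delta handles the partitioning, every LazyFrame scans the same table and is filtered here
@asset(io_manager_def=delta_io_manager)
def all_events(daily_events: Dict[str, pl.LazyFrame]) -> pl.DataFrame:
    frames = [lf.filter(pl.col("partition") == key) for key, lf in daily_events.items()]
    return pl.concat(frames).collect()


# the Parquet IO manager accepts write options through the same metadata mechanism
@asset(
    io_manager_def=PolarsParquetIOManager(base_dir="/tmp/dagster"),
    metadata={"compression": "zstd", "compression_level": 20, "statistics": True},
)
def events_snapshot(all_events: pl.DataFrame) -> pl.DataFrame:
    return all_events


definitions = Definitions(assets=[daily_events, all_events, events_snapshot])
```

With `"partition_by"` set, partitioning is delegated to Delta Lake, so every asset partition of `daily_events` lands in the same `daily_events.delta` directory, and downstream assets are responsible for filtering the partitions they need, as described in the README change above.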