This repository has been archived by the owner on Mar 11, 2024. It is now read-only.

Feat/add polars delta merge support #47

Open · wants to merge 9 commits into master
Changes from 8 commits
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -37,6 +37,7 @@ jobs:
- "0.17.0"
- "0.18.0"
- "0.19.0"
- "0.20.1"
Owner:
nit: 0.20.0 would be more in line with the others

Author:
Hi, I understand, but Polars merge support was added in 0.20. Is it okay to change this to 0.20.0?

Owner (@danielgafni, Dec 19, 2023):
0.20.0 in this context means ">=0.20.0, <0.21.0", or "latest available before 0.21.0". That's how CI is set up. So yes, it's okay to use 0.20.0 here.

steps:
- name: Setup python for test ${{ matrix.py }}
uses: actions/setup-python@v2
@@ -81,6 +82,7 @@ jobs:
- "0.17.0"
- "0.18.0"
- "0.19.0"
- "0.20.0" #minimal version for delta merge
steps:
- name: Setup python for test ${{ matrix.py }}
uses: actions/setup-python@v2
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
@@ -43,3 +43,5 @@ repos:
entry: pyright .
language: system
pass_filenames: false
+ language: system
+ pass_filenames: false
33 changes: 25 additions & 8 deletions dagster_polars/io_managers/delta.py
@@ -24,6 +24,7 @@ class DeltaWriteMode(str, Enum):
append = "append"
overwrite = "overwrite"
ignore = "ignore"
merge = "merge"


class PolarsDeltaIOManager(BasePolarsUPathIOManager):
@@ -49,7 +50,10 @@ def dump_df_to_path(
):
assert context.metadata is not None

-    delta_write_options = context.metadata.get("delta_write_options")
+    if (context.metadata.get("mode") or self.mode) != "merge":
+        delta_write_options = context.metadata.get("delta_write_options")
+    else:
+        delta_merge_options = context.metadata.get("delta_merge_options")

if context.has_asset_partitions:
delta_write_options = delta_write_options or {}
@@ -63,13 +67,26 @@

storage_options = self.get_storage_options(path)

-    df.write_delta(
-        str(path),
-        mode=context.metadata.get("mode") or self.mode,  # type: ignore
-        overwrite_schema=context.metadata.get("overwrite_schema") or self.overwrite_schema,
-        storage_options=storage_options,
-        delta_write_options=delta_write_options,
-    )
+    if (context.metadata.get("mode") or self.mode) != "merge":
+        df.write_delta(
+            str(path),
+            mode=context.metadata.get("mode") or self.mode,  # type: ignore
+            overwrite_schema=context.metadata.get("overwrite_schema") or self.overwrite_schema,
+            storage_options=storage_options,
+            delta_write_options=delta_write_options,
+        )
+    else:
+        (
+            df.write_delta(
+                str(path),
+                mode=context.metadata.get("mode") or self.mode,  # type: ignore
+                storage_options=storage_options,
+                delta_merge_options=delta_merge_options,
+            )
+            .when_matched_update_all()
Contributor:
This needs to be configurable in some way. Basically this is a default upsert, but MERGEs can be a complex set of different update, delete, and insert operations.

I commonly use deduplicate-on-insert.

Author:
Hi Ion,

You are right, and if you look at my example you will find a deduplication strategy using a rank function over a primary key and then selecting the first row. However, for that the input dataset needs to have a "cdc" column (like a load timestamp).

Shouldn't this be the responsibility of the user?
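
A minimal sketch of the deduplicate-then-merge pattern described above; the "pk" (primary key) and "load_dts" (load timestamp, i.e. the "cdc" column) names and the table path are hypothetical:

import polars as pl

# Hypothetical input batch; "pk" and "load_dts" stand in for the real
# primary-key and CDC/load-timestamp columns.
df = pl.DataFrame(
    {
        "pk": [1, 1, 2],
        "load_dts": ["2023-12-01", "2023-12-02", "2023-12-01"],
        "value": ["old", "new", "only"],
    }
)

# Keep only the latest row per primary key before merging.
deduped = df.sort("load_dts", descending=True).unique(
    subset=["pk"], keep="first", maintain_order=True
)

# Upsert into an existing Delta table (needs polars>=0.20.0 and deltalake>=0.14.0).
(
    deduped.write_delta(
        "path/to/delta_table",  # placeholder path
        mode="merge",
        delta_merge_options={
            "predicate": "s.pk = t.pk",
            "source_alias": "s",
            "target_alias": "t",
        },
    )
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute()
)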

An alternative could be to modify:

    def get_metadata(self, context: OutputContext, obj: pl.DataFrame) -> Dict[str, MetadataValue]:
        assert context.metadata is not None

        metadata = super().get_metadata(context, obj)

        if context.has_asset_partitions:
            partition_by = context.metadata.get("partition_by")
            if partition_by is not None:
                metadata["partition_by"] = partition_by

        if context.metadata.get("mode") == "append":
            # modify the metadata to reflect the fact that we are appending to the table

            if context.has_asset_partitions:
                # paths = self._get_paths_for_partitions(context)
                # assert len(paths) == 1
                # path = list(paths.values())[0]

                # FIXME: what to do about row_count metadata if we are appending to a partitioned table?
                # we should not be using the full table length,
                # but it's unclear how to get the length of the partition we are appending to
                pass
            else:
                metadata["append_row_count"] = metadata["row_count"]

                path = self._get_path(context)
                # we need to get row_count from the full table
                metadata["row_count"] = MetadataValue.int(
                    DeltaTable(str(path), storage_options=self.get_storage_options(path))
                    .to_pyarrow_dataset()
                    .count_rows()
                )

        return metadata

To maybe do something like this:

        if context.metadata.get("mode") == "append":
            # modify the medatata to reflect the fact that we are appending to the table

            if context.has_asset_partitions:
                # paths = self._get_paths_for_partitions(context)
                # assert len(paths) == 1
                # path = list(paths.values())[0]

                # FIXME: what to about row_count metadata do if we are appending to a partitioned table?
                # we should not be using the full table length,
                # but it's unclear how to get the length of the partition we are appending to
                pass
            else:
                metadata["append_row_count"] = metadata["row_count"]
       if context.metadata.get("mode") == "merge":
            # modify the medatata to reflect the fact that we are appending to the table
            metadata["primary_key"] == "something here that refers to this key"
            metadata["cdc_column"] == "something here that refers to this key"

                path = self._get_path(context)
                # we need to get row_count from the full table
                metadata["row_count"] = MetadataValue.int(
                    DeltaTable(str(path), storage_options=self.get_storage_options(path))
                    .to_pyarrow_dataset()
                    .count_rows()
                )

        return metadata

Contributor:
Hey :)

Yeah, I need to go a bit more through the current implementation of dagster-polars. I've already pushed a PR for dagster-deltalake-polars as a first step.

Contributor:
@edgBR for my own work I am planning to use dagster-deltalake-polars and then only the parquet IO manager in dagster-polars.

So somewhere next week, after my first PR gets merged in dagster-deltalake-polars, I will expand it there to cover a couple of common MERGE strategies.

+            .when_not_matched_insert_all()
+            .execute()
+        )
current_version = DeltaTable(str(path), storage_options=storage_options).version()
context.add_output_metadata({"version": current_version})

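For context, a sketch of how an asset could opt into the new merge path through output metadata, based on the "mode" and "delta_merge_options" keys read in the diff above (the asset name and the merge predicate are hypothetical):

import polars as pl
from dagster import asset


@asset(
    metadata={
        "mode": "merge",
        # forwarded to DeltaTable.merge via write_delta, per the diff above
        "delta_merge_options": {
            "predicate": "s.pk = t.pk",  # hypothetical join predicate
            "source_alias": "s",
            "target_alias": "t",
        },
    },
)
def upserted_table() -> pl.DataFrame:
    return pl.DataFrame({"pk": [1, 2], "value": ["a", "b"]})
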
7 changes: 4 additions & 3 deletions pyproject.toml
@@ -3,7 +3,8 @@ name = "dagster-polars"
version = "0.0.0"
description = "Dagster integration library for Polars"
authors = [
"Daniel Gafni <[email protected]>"
"Daniel Gafni <[email protected]>",
"Edgar Bahilo <[email protected]>"
]
readme = "README.md"
packages = [{include = "dagster_polars"}]
@@ -28,11 +29,11 @@ license = "Apache-2.0"
[tool.poetry.dependencies]
python = "^3.8"
dagster = "^1.4.0"
polars = ">=0.17.0"
polars = ">=0.20.1"
Owner:
Let's not change the lower polars constraint. We don't want to force an update for users, as it can break their code.

Owner (@danielgafni, Dec 19, 2023):
To be clear, we do want to update the dev polars version pinned in poetry.lock. This can be done via the "poetry update polars" command.

pyarrow = ">=8.0.0"
typing-extensions = "^4.7.1"

deltalake = { version = ">=0.10.0", optional = true }
deltalake = { version = ">=0.14.0", optional = true }
Owner:
Same here, I don't think we need to change this.

dagster-gcp = { version = ">=0.19.5", optional = true }
universal-pathlib = "^0.1.4"

11 changes: 10 additions & 1 deletion tests/test_polars_delta.py
@@ -101,7 +101,16 @@ def append_asset() -> pl.DataFrame:

pl_testing.assert_frame_equal(pl.concat([df, df]), pl.read_delta(saved_path))


+def test_polars_delta_io_manager_merge(polars_delta_io_manager: PolarsDeltaIOManager):
Owner:
Please add a test for the new merge functionality. You can take a look at the other tests for inspiration. Let me know if you need any help with this.

+    df = pl.DataFrame(
+        {
+            "a": [1, 2, 3],
+        }
+    )
+    x = "hello"
+    assert "hello" == x
+####

def test_polars_delta_io_manager_overwrite_schema(
polars_delta_io_manager: PolarsDeltaIOManager, dagster_instance: DagsterInstance
):
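For inspiration, a sketch of what such a merge test could look like, exercising the merge semantics directly with write_delta rather than through the repo's IO manager fixtures (the table layout and column names are illustrative):

import polars as pl
import polars.testing as pl_testing


def test_delta_merge_upsert(tmp_path):
    path = str(tmp_path / "merge_table")

    # Seed the table with two rows.
    pl.DataFrame({"pk": [1, 2], "value": ["a", "b"]}).write_delta(path, mode="overwrite")

    # Merge an overlapping batch: pk=2 should be updated, pk=3 inserted.
    (
        pl.DataFrame({"pk": [2, 3], "value": ["b2", "c"]})
        .write_delta(
            path,
            mode="merge",
            delta_merge_options={
                "predicate": "s.pk = t.pk",
                "source_alias": "s",
                "target_alias": "t",
            },
        )
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .execute()
    )

    result = pl.read_delta(path).sort("pk")
    pl_testing.assert_frame_equal(
        result, pl.DataFrame({"pk": [1, 2, 3], "value": ["a", "b2", "c"]})
    )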