From 2e33cb8f5721d2236e1a64ac63a2f9b1dec28a52 Mon Sep 17 00:00:00 2001 From: ion-elgreco <15728914+ion-elgreco@users.noreply.github.com> Date: Sun, 28 Jan 2024 11:36:31 +0100 Subject: [PATCH] fix overloads --- dagster_polars/io_managers/delta.py | 14 +++++++++++++- dagster_polars/io_managers/parquet.py | 20 ++++++++++++++++---- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/dagster_polars/io_managers/delta.py b/dagster_polars/io_managers/delta.py index de48bc9..3418ab5 100644 --- a/dagster_polars/io_managers/delta.py +++ b/dagster_polars/io_managers/delta.py @@ -1,7 +1,7 @@ import json from enum import Enum from pprint import pformat -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, overload import dagster._check as check import polars as pl @@ -241,6 +241,18 @@ def write_df_to_path( metadata_path.parent.mkdir(parents=True, exist_ok=True) metadata_path.write_text(json.dumps(metadata)) + @overload + def scan_df_from_path( + self, path: "UPath", context: InputContext, with_metadata: Literal[None, False] + ) -> pl.LazyFrame: + ... + + @overload + def scan_df_from_path( + self, path: "UPath", context: InputContext, with_metadata: Literal[True] + ) -> LazyFrameWithMetadata: + ... + def scan_df_from_path( self, path: "UPath", diff --git a/dagster_polars/io_managers/parquet.py b/dagster_polars/io_managers/parquet.py index 3b0389e..5bbf8cb 100644 --- a/dagster_polars/io_managers/parquet.py +++ b/dagster_polars/io_managers/parquet.py @@ -1,5 +1,5 @@ import json -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, overload import polars as pl import pyarrow.dataset as ds @@ -145,7 +145,7 @@ def downsteam(upstream: DataFrameWithMetadata): assert metadata["my_custom_metadata"] == "my_custom_value" """ - extension: str = ".parquet" + extension: str = ".parquet" # type: ignore def sink_df_to_path( self, @@ -256,12 +256,24 @@ def write_df_to_path( row_group_size=row_group_size, ) - def scan_df_from_path( # type: ignore + @overload + def scan_df_from_path( + self, path: "UPath", context: InputContext, with_metadata: Literal[None, False] + ) -> pl.LazyFrame: + ... + + @overload + def scan_df_from_path( + self, path: "UPath", context: InputContext, with_metadata: Literal[True] + ) -> LazyFrameWithMetadata: + ... + + def scan_df_from_path( self, path: "UPath", context: InputContext, - partition_key: Optional[str] = None, with_metadata: Optional[bool] = False, + partition_key: Optional[str] = None, ) -> Union[pl.LazyFrame, LazyFrameWithMetadata]: ldf = scan_parquet(path, context)