Skip to content

Commit

Permalink
Bugfix/763 improve type annotations for DataFrameModel.validate (#1905)
Browse files Browse the repository at this point in the history
* trial type annotations

Signed-off-by: Matt Richards <[email protected]>

* changes in individual api files

Signed-off-by: Matt Richards <[email protected]>

* pl.dataframe working in local test

Signed-off-by: Matt Richards <[email protected]>

* older python union compat

Signed-off-by: Matt Richards <[email protected]>

* try polars in the mypy env on ci

Signed-off-by: Matt Richards <[email protected]>

* translate toplevel mypy skip into module specific skips

Signed-off-by: Matt Richards <[email protected]>

* mypy passes

Signed-off-by: Matt Richards <[email protected]>

* missing line continuation

Signed-off-by: Matt Richards <[email protected]>

* python 3.8

Signed-off-by: Matt Richards <[email protected]>

---------

Signed-off-by: Matt Richards <[email protected]>
  • Loading branch information
m-richards authored Feb 17, 2025
1 parent 32b08fd commit 7186507
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 9 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ jobs:
types-pytz \
types-pyyaml \
types-requests \
types-setuptools
types-setuptools \
polars
- name: Pip info
run: python -m pip list

Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ repos:
- types-pyyaml
- types-requests
- types-setuptools
- polars
args: ["pandera", "tests", "scripts"]
exclude: (^docs/|^tests/mypy/modules/)
pass_filenames: false
Expand Down
16 changes: 15 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[mypy]
ignore_missing_imports = True
follow_imports = skip
follow_imports = normal
allow_redefinition = True
warn_return_any = False
warn_unused_configs = True
Expand All @@ -12,3 +12,17 @@ exclude=(?x)(
| ^pandera/backends/pyspark
| ^tests/pyspark
)
[mypy-pandera.api.pyspark.*]
follow_imports = skip

[mypy-docs.*]
follow_imports = skip

[mypy-pandera.engines.polars_engine]
ignore_errors = True

[mypy-pandera.backends.polars.builtin_checks]
ignore_errors = True

[mypy-tests.polars.*]
ignore_errors = True
24 changes: 23 additions & 1 deletion pandera/api/pandas/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import copy
import sys
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast

import pandas as pd

from pandera.api.base.schema import BaseSchema
from pandera.api.checks import Check
from pandera.api.dataframe.model import DataFrameModel as _DataFrameModel
from pandera.api.dataframe.model import get_dtype_kwargs
Expand All @@ -22,6 +23,7 @@
AnnotationInfo,
DataFrame,
)
from pandera.utils import docstring_substitution

# if python version is < 3.11, import Self from typing_extensions
if sys.version_info < (3, 11):
Expand Down Expand Up @@ -171,6 +173,26 @@ def _build_columns_index( # pylint:disable=too-many-locals,too-many-branches

return columns, _build_schema_index(indices, **multiindex_kwargs)

@classmethod
@docstring_substitution(validate_doc=BaseSchema.validate.__doc__)
def validate(
cls: Type[Self],
check_obj: pd.DataFrame,
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None,
lazy: bool = False,
inplace: bool = False,
) -> DataFrame[Self]:
"""%(validate_doc)s"""
return cast(
DataFrame[Self],
cls.to_schema().validate(
check_obj, head, tail, sample, random_state, lazy, inplace
),
)

@classmethod
def to_json_schema(cls):
"""Serialize schema metadata into json-schema format.
Expand Down
2 changes: 1 addition & 1 deletion pandera/api/polars/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Column(ComponentSchema[PolarsCheckObjects]):

def __init__(
self,
dtype: PolarsDtypeInputTypes = None,
dtype: Optional[PolarsDtypeInputTypes] = None,
checks: Optional[CheckList] = None,
nullable: bool = False,
unique: bool = False,
Expand Down
54 changes: 52 additions & 2 deletions pandera/api/polars/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Class-based api for polars models."""

import inspect
from typing import Dict, List, Tuple, Type
from typing import Dict, List, Tuple, Type, cast, Optional, overload, Union
from typing_extensions import Self

import pandas as pd
import polars as pl

from pandera.api.base.schema import BaseSchema
from pandera.api.checks import Check
from pandera.api.dataframe.model import DataFrameModel as _DataFrameModel
from pandera.api.dataframe.model import get_dtype_kwargs
Expand All @@ -16,7 +18,8 @@
from pandera.engines import polars_engine as pe
from pandera.errors import SchemaInitError
from pandera.typing import AnnotationInfo
from pandera.typing.polars import Series
from pandera.typing.polars import Series, LazyFrame, DataFrame
from pandera.utils import docstring_substitution


class DataFrameModel(_DataFrameModel[pl.LazyFrame, DataFrameSchema]):
Expand Down Expand Up @@ -109,6 +112,53 @@ def _build_columns( # pylint:disable=too-many-locals

return columns

@classmethod
@overload
def validate(
cls: Type[Self],
check_obj: pl.DataFrame,
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None,
lazy: bool = False,
inplace: bool = False,
) -> DataFrame[Self]: ...

@classmethod
@overload
def validate(
cls: Type[Self],
check_obj: pl.LazyFrame,
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None,
lazy: bool = False,
inplace: bool = False,
) -> LazyFrame[Self]: ...

@classmethod
@docstring_substitution(validate_doc=BaseSchema.validate.__doc__)
def validate(
cls: Type[Self],
check_obj: Union[pl.LazyFrame, pl.DataFrame],
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None,
lazy: bool = False,
inplace: bool = False,
) -> Union[LazyFrame[Self], DataFrame[Self]]:
"""%(validate_doc)s"""
result = cls.to_schema().validate(
check_obj, head, tail, sample, random_state, lazy, inplace
)
if isinstance(check_obj, pl.LazyFrame):
return cast(LazyFrame[Self], result)
else:
return cast(DataFrame[Self], result)

@classmethod
def to_json_schema(cls):
"""Serialize schema metadata into json-schema format.
Expand Down
5 changes: 3 additions & 2 deletions pandera/api/pyspark/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from pandera.errors import SchemaInitError
from pandera.typing import AnnotationInfo
from pandera.typing.common import DataFrameBase
from pandera.typing.pyspark import DataFrame

try:
from typing_extensions import get_type_hints
Expand Down Expand Up @@ -300,10 +301,10 @@ def validate(
random_state: Optional[int] = None,
lazy: bool = True,
inplace: bool = False,
) -> Optional[DataFrameBase[TDataFrameModel]]:
) -> DataFrame[TDataFrameModel]:
"""%(validate_doc)s"""
return cast(
DataFrameBase[TDataFrameModel],
DataFrame[TDataFrameModel],
cls.to_schema().validate(
check_obj, head, tail, sample, random_state, lazy, inplace
),
Expand Down
5 changes: 4 additions & 1 deletion pandera/backends/polars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def subsample(
obj_subsample.append(check_obj.tail(tail))
if sample is not None:
obj_subsample.append(
check_obj.sample(sample, random_state=random_state)
# mypy is detecting a bug https://github.com/unionai-oss/pandera/issues/1912
check_obj.sample( # type:ignore [attr-defined]
sample, random_state=random_state
)
)
return (
check_obj
Expand Down

0 comments on commit 7186507

Please sign in to comment.