Skip to content

Commit

Permalink
bugfix: timezone-agnostic datetime in polars works in DataFrameModel (#…
Browse files Browse the repository at this point in the history
…1638)

Signed-off-by: cosmicBboy <[email protected]>
  • Loading branch information
cosmicBboy committed May 14, 2024
1 parent 95e412f commit d2bfed0
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 5 deletions.
6 changes: 5 additions & 1 deletion docs/source/polars.md
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,11 @@ from pandera.engines.polars_engine import DateTime
class DateTimeModel(pa.DataFrameModel):
created_at: Annotated[DateTime, True]
created_at: Annotated[DateTime, True, "us", None]
```
.
```{note}
For `Annotated` types, you need to pass in all positional and keyword arguments.
```

:::
Expand Down
15 changes: 11 additions & 4 deletions pandera/api/polars/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Class-based api for polars models."""

import inspect
from typing import Dict, List, Tuple, Type

import pandas as pd
Expand Down Expand Up @@ -47,10 +48,16 @@ def _build_columns( # pylint:disable=too-many-locals
field_name = field.name
check_name = getattr(field, "check_name", None)

engine_dtype = None
try:
engine_dtype = pe.Engine.dtype(annotation.raw_annotation)
dtype = engine_dtype.type
if inspect.isclass(annotation.raw_annotation) and issubclass(
annotation.raw_annotation, pe.DataType
):
# use the raw annotation as the dtype if it's a native
# pandera polars datatype
dtype = annotation.raw_annotation
else:
dtype = engine_dtype.type
except (TypeError, ValueError) as exc:
if annotation.metadata:
if field.dtype_kwargs:
Expand All @@ -64,13 +71,13 @@ def _build_columns( # pylint:disable=too-many-locals
elif annotation.default_dtype:
dtype = annotation.default_dtype
else:
dtype = annotation.arg
dtype = annotation.arg # type: ignore

if (
annotation.origin is None
or isinstance(annotation.origin, pl.datatypes.DataTypeClass)
or annotation.origin is Series
or engine_dtype
or dtype
):
if check_name is False:
raise SchemaInitError(
Expand Down
59 changes: 59 additions & 0 deletions tests/polars/test_polars_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,18 @@
import sys
from typing import Optional

try: # python 3.9+
from typing import Annotated # type: ignore
except ImportError:
from typing_extensions import Annotated # type: ignore

import polars as pl
import pytest
from hypothesis import given
from hypothesis import strategies as st
from polars.testing.parametric import column, dataframes

import pandera.engines.polars_engine as pe
from pandera.errors import SchemaError
from pandera.polars import (
Column,
Expand Down Expand Up @@ -211,3 +220,53 @@ class ModelWithNestedDtypes(DataFrameModel):

schema = ModelWithNestedDtypes.to_schema()
assert schema_with_list_type == schema


@pytest.mark.parametrize(
"time_zone",
[
None,
"UTC",
"GMT",
"EST",
],
)
@given(st.data())
def test_dataframe_schema_with_tz_agnostic_dates(time_zone, data):
strategy = dataframes(
column("datetime_col", dtype=pl.Datetime()),
lazy=True,
size=10,
)
lf = data.draw(strategy)
lf = lf.cast({"datetime_col": pl.Datetime(time_zone=time_zone)})

class ModelTZAgnosticKwargs(DataFrameModel):
datetime_col: pe.DateTime = Field(
dtype_kwargs={"time_zone_agnostic": True}
)

class ModelTZSensitiveKwargs(DataFrameModel):
datetime_col: pe.DateTime = Field(
dtype_kwargs={"time_zone_agnostic": False}
)

class ModelTZAgnosticAnnotated(DataFrameModel):
datetime_col: Annotated[pe.DateTime, True, "us", None]

class ModelTZSensitiveAnnotated(DataFrameModel):
datetime_col: Annotated[pe.DateTime, False, "us", None]

for tz_agnostic_model in (
ModelTZAgnosticKwargs,
ModelTZAgnosticAnnotated,
):
tz_agnostic_model.validate(lf)

for tz_sensitive_model in (
ModelTZSensitiveKwargs,
ModelTZSensitiveAnnotated,
):
if time_zone:
with pytest.raises(SchemaError):
tz_sensitive_model.validate(lf)

0 comments on commit d2bfed0

Please sign in to comment.