Skip to content

Commit

Permalink
Bugfix/1677 Fix Pandera DataFrame - Pydantic compatibility (#1904)
Browse files Browse the repository at this point in the history
* fix DataFrame Pydantic compatibility

* format python file

* update test for new code

* prevents Linters from raising an error

Signed-off-by: Jarek-Rolski <[email protected]>

* enable pyarrow and other types in pydantic models

Signed-off-by: Jarek-Rolski <[email protected]>

---------

Signed-off-by: Jarek-Rolski <[email protected]>
  • Loading branch information
Jarek-Rolski authored Feb 10, 2025
1 parent 70a49d5 commit 754e66d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
32 changes: 25 additions & 7 deletions pandera/typing/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@
SeriesBase,
)
from pandera.typing.formats import Formats

try:
from typing import get_args
except ImportError:
from typing_extensions import get_args

from pandera.config import config_context

try:
from typing import _GenericAlias # type: ignore[attr-defined]
Expand Down Expand Up @@ -191,12 +186,35 @@ def _get_schema_model(cls, field):
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
schema_model = get_args(_source_type)[0]
# prevent validation in __setattr__ function in DataFrameBase class
with config_context(validation_enabled=False):
schema_model = _source_type().__orig_class__.__args__[0]
schema = schema_model.to_schema()
schema_json_columns = schema_model.to_json_schema()["properties"]
type_map = {
"string": core_schema.str_schema(),
"integer": core_schema.int_schema(),
"number": core_schema.float_schema(),
"boolean": core_schema.bool_schema(),
"datetime": core_schema.datetime_schema(),
}
return core_schema.no_info_plain_validator_function(
functools.partial(
cls.pydantic_validate,
schema_model=schema_model,
),
json_schema_input_schema=core_schema.list_schema(
core_schema.typed_dict_schema(
{
key: core_schema.typed_dict_field(
type_map[
schema_json_columns[key]["items"]["type"]
]
)
for key in schema.columns.keys()
},
)
),
)

else:
Expand Down
11 changes: 7 additions & 4 deletions tests/core/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_typed_dataframe():


def test_invalid_typed_dataframe():
"""Test that an invalid typed DataFrame is recognized by pydantic."""
"""Test that an invalid typed DataFrame is recognized by pandera."""
with pytest.raises(ValidationError):
TypedDfPydantic(df=1)

Expand All @@ -74,10 +74,13 @@ class InvalidSchema(pa.DataFrameModel):

str_col = pa.Field(unique=True) # omit annotation

class PydanticModel(BaseModel):
pa_schema: DataFrame[InvalidSchema]
with pytest.raises(pa.errors.SchemaInitError):

class PydanticModel(BaseModel):
pa_schema: DataFrame[InvalidSchema]

with pytest.raises(ValueError):
# This check prevents Linters from raising an error about not using the PydanticModel class
with pytest.raises(UnboundLocalError):
PydanticModel(pa_schema=InvalidSchema)


Expand Down

0 comments on commit 754e66d

Please sign in to comment.