Skip to content

Commit

Permalink
Bugfix/583: better handling of datetime/timedelta in serialize/deserialize (#585)
Browse files Browse the repository at this point in the history

* yaml serialize/deserialize support non-datetime/timedelta check stats

* add test
  • Loading branch information
cosmicBboy authored Aug 6, 2021
1 parent 589d4a9 commit bada11d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 10 deletions.
26 changes: 18 additions & 8 deletions pandera/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,17 @@ def _serialize_check_stats(check_stats, dtype=None):
"""Serialize check statistics into json/yaml-compatible format."""

def handle_stat_dtype(stat):
if pandas_engine.Engine.dtype(dtypes.DateTime).check(dtype):
if pandas_engine.Engine.dtype(dtypes.DateTime).check(
dtype
) and hasattr(stat, "strftime"):
# try serializing stat as a string if it's datetime-like,
# otherwise return original value
return stat.strftime(DATETIME_FORMAT)
elif pandas_engine.Engine.dtype(dtypes.Timedelta).check(dtype):
# serialize to int in nanoseconds
elif pandas_engine.Engine.dtype(dtypes.Timedelta).check(
dtype
) and hasattr(stat, "delta"):
# try serializing stat into an int in nanoseconds if it's
# timedelta-like, otherwise return original value
return stat.delta

return stat
Expand Down Expand Up @@ -146,11 +153,14 @@ def _serialize_schema(dataframe_schema):

def _deserialize_check_stats(check, serialized_check_stats, dtype=None):
def handle_stat_dtype(stat):
if pandas_engine.Engine.dtype(dtypes.DateTime).check(dtype):
return pd.to_datetime(stat, format=DATETIME_FORMAT)
elif pandas_engine.Engine.dtype(dtypes.Timedelta).check(dtype):
# serialize to int in nanoseconds
return pd.to_timedelta(stat, unit="ns")
try:
if pandas_engine.Engine.dtype(dtypes.DateTime).check(dtype):
return pd.to_datetime(stat, format=DATETIME_FORMAT)
elif pandas_engine.Engine.dtype(dtypes.Timedelta).check(dtype):
# deserialize from an int in nanoseconds
return pd.to_timedelta(stat, unit="ns")
except (TypeError, ValueError):
return stat
return stat

if isinstance(serialized_check_stats, dict):
Expand Down
32 changes: 30 additions & 2 deletions tests/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,8 +609,8 @@ def test_to_yaml_custom_dataframe_check():
# `test_to_yaml_lambda_check`


def test_to_yaml_bugfix_419():
"""Ensure that GH#419 is fixed"""
def test_to_yaml_bugfix_warn_unregistered_global_checks():
"""Ensure that unregistered global checks raise a warning."""
# pylint: disable=no-self-use

class CheckedSchemaModel(pandera.SchemaModel):
Expand All @@ -628,6 +628,34 @@ def unregistered_check(self, _):
CheckedSchemaModel.to_yaml()


def test_serialize_deserialize_custom_datetime_checks():
    """
    Round-trip a schema through yaml when its datetime and timedelta columns
    carry a custom registered check whose statistic is a plain string.
    """

    # pylint: disable=unused-variable,unused-argument
    @pandera.extensions.register_check_method(statistics=["stat"])
    def datetime_check(pandas_obj, *, stat):
        # placeholder body: this test only serializes and deserializes the
        # schema, it does not validate any data with it
        ...

    column_dtypes = (
        ("dt_col", pandera.DateTime),
        ("td_col", pandera.Timedelta),
    )
    schema = pandera.DataFrameSchema(
        {
            column_name: pandera.Column(
                column_dtype,
                checks=pandera.Check.datetime_check("foobar"),
            )
            for column_name, column_dtype in column_dtypes
        }
    )

    # serialize to yaml and read it straight back; the result must compare
    # equal to the schema we started from
    roundtripped = schema.from_yaml(schema.to_yaml())
    assert roundtripped == schema


FRICTIONLESS_YAML = yaml.safe_load(
"""
fields:
Expand Down

0 comments on commit bada11d

Please sign in to comment.