Skip to content

Commit

Permalink
fix to_script output (#994)
Browse files Browse the repository at this point in the history
* fix to_script output

use string alias representation when writing schema to a script.
string aliases are more portable and don't rely on specifics of the
pandas API, especially for parameterized datatypes.

Signed-off-by: Niels Bantilan <[email protected]>

* include other DataFrameSchema attributes in serialization

Signed-off-by: Niels Bantilan <[email protected]>

Signed-off-by: Niels Bantilan <[email protected]>
  • Loading branch information
cosmicBboy authored Nov 1, 2022
1 parent 6da537f commit ba04d1b
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 17 deletions.
32 changes: 26 additions & 6 deletions docs/source/schema_inference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr
schema = DataFrameSchema(
columns={
"column1": Column(
dtype=pandera.engines.numpy_engine.Int64,
dtype="int64",
checks=[
Check.greater_than_or_equal_to(min_value=5.0),
Check.less_than_or_equal_to(max_value=20.0),
Expand All @@ -116,7 +116,7 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr
title=None,
),
"column2": Column(
dtype=pandera.engines.numpy_engine.Object,
dtype="object",
checks=None,
nullable=False,
unique=False,
Expand All @@ -127,7 +127,7 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr
title=None,
),
"column3": Column(
dtype=pandera.engines.pandas_engine.DateTime,
dtype="datetime64[ns]",
checks=[
Check.greater_than_or_equal_to(
min_value=Timestamp("2010-01-01 00:00:00")
Expand All @@ -145,8 +145,9 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr
title=None,
),
},
checks=None,
index=Index(
dtype=pandera.engines.numpy_engine.Int64,
dtype="int64",
checks=[
Check.greater_than_or_equal_to(min_value=0.0),
Check.less_than_or_equal_to(max_value=2.0),
Expand All @@ -157,9 +158,16 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr
description=None,
title=None,
),
dtype=None,
coerce=True,
strict=False,
name=None,
ordered=False,
unique=None,
report_duplicates="all",
unique_column_names=False,
title=None,
description=None,
)

As a python script, you can iterate on an inferred schema and use it to
Expand Down Expand Up @@ -234,10 +242,16 @@ is a convenience method for this functionality.
name: null
unique: false
coerce: false
dtype: null
coerce: true
strict: false
unique: null
name: null
ordered: false
unique: null
report_duplicates: all
unique_column_names: false
title: null
description: null

You can edit this yaml file to modify the schema. For example, you can specify
new column names under the ``column`` key, and the respective values map onto
Expand Down Expand Up @@ -328,10 +342,16 @@ is a convenience method for this functionality.
"coerce": false
}
],
"dtype": null,
"coerce": true,
"strict": false,
"name": null,
"ordered": false,
"unique": null,
"ordered": false
"report_duplicates": "all",
"unique_column_names": false,
"title": null,
"description": null
}

You can edit this json file to update the schema as needed, and then load
Expand Down
6 changes: 5 additions & 1 deletion pandera/engines/pandas_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,11 @@ def _get_to_datetime_fn(obj: Any) -> Callable:
)
@immutable(init=True)
class DateTime(_BaseDateTime, dtypes.Timestamp):
"""Semantic representation of a :class:`pandas.DatetimeTZDtype`."""
"""Semantic representation of a potentially timezone-aware datetime.
Uses ``np.dtype("datetime64[ns]")`` for non-timezone aware datetimes and
:class:`pandas.DatetimeTZDtype` for timezone-aware datetimes.
"""

type: Optional[_PandasDatetime] = dataclasses.field(
default=None, init=False
Expand Down
57 changes: 49 additions & 8 deletions pandera/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,22 @@
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"


def _get_qualified_name(cls: type) -> str:
return f"{cls.__module__}.{cls.__qualname__}"
def _get_dtype_string_alias(dtype: pandas_engine.DataType) -> str:
    """Serialize *dtype* as its quoted string alias.

    The alias must round-trip: ``pandas_engine.Engine.dtype(<string_alias>)``
    has to be a valid call so that a written schema can be loaded back.

    :raises TypeError: if the string alias is not recognized by the engine.
    """
    str_alias = str(dtype)
    try:
        pandas_engine.Engine.dtype(str_alias)
    except TypeError as e:  # pragma: no cover
        dtype_path = f"{dtype.__module__}.{dtype.__class__.__name__}"
        raise TypeError(
            f"string alias {str_alias} for datatype "
            f"'{dtype_path}' not "
            "recognized."
        ) from e
    # reuse the already-computed alias instead of re-stringifying the dtype
    return f'"{str_alias}"'


def _serialize_check_stats(check_stats, dtype=None):
Expand Down Expand Up @@ -152,10 +166,16 @@ def serialize_schema(dataframe_schema):
"columns": columns,
"checks": checks,
"index": index,
"dtype": dataframe_schema.dtype,
"coerce": dataframe_schema.coerce,
"strict": dataframe_schema.strict,
"unique": dataframe_schema.unique,
"name": dataframe_schema.name,
"ordered": dataframe_schema.ordered,
"unique": dataframe_schema.unique,
"report_duplicates": dataframe_schema._report_duplicates,
"unique_column_names": dataframe_schema.unique_column_names,
"title": dataframe_schema.title,
"description": dataframe_schema.description,
}


Expand Down Expand Up @@ -272,10 +292,18 @@ def deserialize_schema(serialized_schema):
columns=columns,
checks=checks,
index=index,
dtype=serialized_schema.get("dtype", None),
coerce=serialized_schema.get("coerce", False),
strict=serialized_schema.get("strict", False),
unique=serialized_schema.get("unique", None),
name=serialized_schema.get("name", None),
ordered=serialized_schema.get("ordered", False),
unique=serialized_schema.get("unique", None),
report_duplicates=serialized_schema.get("_report_duplicates", "all"),
unique_column_names=serialized_schema.get(
"unique_column_names", False
),
title=serialized_schema.get("title", None),
description=serialized_schema.get("description", None),
)


Expand Down Expand Up @@ -370,10 +398,18 @@ def to_json(dataframe_schema, target=None, **kwargs):
schema = DataFrameSchema(
columns={{{columns}}},
checks={checks},
index={index},
dtype={dtype},
coerce={coerce},
strict={strict},
name={name},
ordered={ordered},
unique={unique},
report_duplicates={report_duplicates},
unique_column_names={unique_column_names},
title={title},
description={description},
)
"""

Expand Down Expand Up @@ -434,7 +470,7 @@ def _format_index(index_statistics):
description = properties.get("description")
title = properties.get("title")
index_code = INDEX_TEMPLATE.format(
dtype=f"{_get_qualified_name(dtype.__class__)}",
dtype=(None if dtype is None else _get_dtype_string_alias(dtype)),
checks=(
"None"
if properties["checks"] is None
Expand Down Expand Up @@ -479,9 +515,7 @@ def to_script(dataframe_schema, path_or_buf=None):
description = properties["description"]
title = properties["title"]
column_code = COLUMN_TEMPLATE.format(
dtype=(
None if dtype is None else _get_qualified_name(dtype.__class__)
),
dtype=(None if dtype is None else _get_dtype_string_alias(dtype)),
checks=_format_checks(properties["checks"]),
nullable=properties["nullable"],
unique=properties["unique"],
Expand All @@ -503,11 +537,18 @@ def to_script(dataframe_schema, path_or_buf=None):

script = SCRIPT_TEMPLATE.format(
columns=column_str,
checks=statistics["checks"],
index=index,
dtype=dataframe_schema.dtype,
coerce=dataframe_schema.coerce,
strict=dataframe_schema.strict,
name=dataframe_schema.name.__repr__(),
ordered=dataframe_schema.ordered,
unique=dataframe_schema.unique,
report_duplicates=f'"{dataframe_schema._report_duplicates}"',
unique_column_names=dataframe_schema.unique_column_names,
title=dataframe_schema.title,
description=dataframe_schema.description,
).strip()

# add pandas imports to handle datetime and timedelta.
Expand Down
16 changes: 14 additions & 2 deletions tests/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,16 @@ def _create_schema(index="single"):
name: null
unique: false
coerce: false
dtype: null
coerce: false
strict: true
unique: null
name: null
ordered: false
unique: null
report_duplicates: all
unique_column_names: false
title: null
description: null
"""


Expand Down Expand Up @@ -1215,10 +1221,16 @@ def datetime_check(pandas_obj, *, stat):
regex: false
checks: null
index: null
dtype: null
coerce: true
strict: true
unique: null
name: null
ordered: false
unique: null
report_duplicates: all
unique_column_names: false
title: null
description: null
"""

VALID_FRICTIONLESS_DF = pd.DataFrame(
Expand Down

0 comments on commit ba04d1b

Please sign in to comment.