From ba04d1be01803c6e8ca813ab45dc3b228b028057 Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Tue, 1 Nov 2022 15:15:29 -0400 Subject: [PATCH] fix to_script output (#994) * fix to_script output use string alias representation when writing schame to a script. string aliases are more portable and don't rely on specifics of the pandas API, especially for parameterized datatypes. Signed-off-by: Niels Bantilan * include other DataFrameSchema attributes in serialization Signed-off-by: Niels Bantilan Signed-off-by: Niels Bantilan --- docs/source/schema_inference.rst | 32 ++++++++++++++---- pandera/engines/pandas_engine.py | 6 +++- pandera/io.py | 57 +++++++++++++++++++++++++++----- tests/io/test_io.py | 16 +++++++-- 4 files changed, 94 insertions(+), 17 deletions(-) diff --git a/docs/source/schema_inference.rst b/docs/source/schema_inference.rst index 41a37b50a..ba7f3c9fa 100644 --- a/docs/source/schema_inference.rst +++ b/docs/source/schema_inference.rst @@ -102,7 +102,7 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr schema = DataFrameSchema( columns={ "column1": Column( - dtype=pandera.engines.numpy_engine.Int64, + dtype="int64", checks=[ Check.greater_than_or_equal_to(min_value=5.0), Check.less_than_or_equal_to(max_value=20.0), @@ -116,7 +116,7 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr title=None, ), "column2": Column( - dtype=pandera.engines.numpy_engine.Object, + dtype="object", checks=None, nullable=False, unique=False, @@ -127,7 +127,7 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr title=None, ), "column3": Column( - dtype=pandera.engines.pandas_engine.DateTime, + dtype="datetime64[ns]", checks=[ Check.greater_than_or_equal_to( min_value=Timestamp("2010-01-01 00:00:00") @@ -145,8 +145,9 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr title=None, ), }, + checks=None, index=Index( - dtype=pandera.engines.numpy_engine.Int64, + dtype="int64", checks=[ Check.greater_than_or_equal_to(min_value=0.0), Check.less_than_or_equal_to(max_value=2.0), @@ -157,9 +158,16 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr description=None, title=None, ), + dtype=None, coerce=True, strict=False, name=None, + ordered=False, + unique=None, + report_duplicates="all", + unique_column_names=False, + title=None, + description=None, ) As a python script, you can iterate on an inferred schema and use it to @@ -234,10 +242,16 @@ is a convenience method for this functionality. name: null unique: false coerce: false + dtype: null coerce: true strict: false - unique: null + name: null ordered: false + unique: null + report_duplicates: all + unique_column_names: false + title: null + description: null You can edit this yaml file to modify the schema. For example, you can specify new column names under the ``column`` key, and the respective values map onto @@ -328,10 +342,16 @@ is a convenience method for this functionality. "coerce": false } ], + "dtype": null, "coerce": true, "strict": false, + "name": null, + "ordered": false, "unique": null, - "ordered": false + "report_duplicates": "all", + "unique_column_names": false, + "title": null, + "description": null } You can edit this json file to update the schema as needed, and then load diff --git a/pandera/engines/pandas_engine.py b/pandera/engines/pandas_engine.py index 5e2600ae9..530d6bfa2 100644 --- a/pandera/engines/pandas_engine.py +++ b/pandera/engines/pandas_engine.py @@ -745,7 +745,11 @@ def _get_to_datetime_fn(obj: Any) -> Callable: ) @immutable(init=True) class DateTime(_BaseDateTime, dtypes.Timestamp): - """Semantic representation of a :class:`pandas.DatetimeTZDtype`.""" + """Semantic representation of a potentially timezone-aware datetime. + + Uses ``np.dtype("datetime64[ns]")`` for non-timezone aware datetimes and + :class:`pandas.DatetimeTZDtype` for timezone-aware datetimes. + """ type: Optional[_PandasDatetime] = dataclasses.field( default=None, init=False diff --git a/pandera/io.py b/pandera/io.py index a83a5e63d..10fb8989a 100644 --- a/pandera/io.py +++ b/pandera/io.py @@ -34,8 +34,22 @@ DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S" -def _get_qualified_name(cls: type) -> str: - return f"{cls.__module__}.{cls.__qualname__}" +def _get_dtype_string_alias(dtype: pandas_engine.DataType) -> str: + """Get string alias of the datatype for serialization. + + Calling pandas_engine.Engine.dtype() should be a valid + operation + """ + str_alias = str(dtype) + try: + pandas_engine.Engine.dtype(str_alias) + except TypeError as e: # pragma: no cover + raise TypeError( + f"string alias {str_alias} for datatype " + f"'{dtype.__module__}.{dtype.__class__.__name__}' not " + "recognized." + ) from e + return f'"{dtype}"' def _serialize_check_stats(check_stats, dtype=None): @@ -152,10 +166,16 @@ def serialize_schema(dataframe_schema): "columns": columns, "checks": checks, "index": index, + "dtype": dataframe_schema.dtype, "coerce": dataframe_schema.coerce, "strict": dataframe_schema.strict, - "unique": dataframe_schema.unique, + "name": dataframe_schema.name, "ordered": dataframe_schema.ordered, + "unique": dataframe_schema.unique, + "report_duplicates": dataframe_schema._report_duplicates, + "unique_column_names": dataframe_schema.unique_column_names, + "title": dataframe_schema.title, + "description": dataframe_schema.description, } @@ -272,10 +292,18 @@ def deserialize_schema(serialized_schema): columns=columns, checks=checks, index=index, + dtype=serialized_schema.get("dtype", None), coerce=serialized_schema.get("coerce", False), strict=serialized_schema.get("strict", False), - unique=serialized_schema.get("unique", None), + name=serialized_schema.get("name", None), ordered=serialized_schema.get("ordered", False), + unique=serialized_schema.get("unique", None), + report_duplicates=serialized_schema.get("_report_duplicates", "all"), + unique_column_names=serialized_schema.get( + "unique_column_names", False + ), + title=serialized_schema.get("title", None), + description=serialized_schema.get("description", None), ) @@ -370,10 +398,18 @@ def to_json(dataframe_schema, target=None, **kwargs): schema = DataFrameSchema( columns={{{columns}}}, + checks={checks}, index={index}, + dtype={dtype}, coerce={coerce}, strict={strict}, name={name}, + ordered={ordered}, + unique={unique}, + report_duplicates={report_duplicates}, + unique_column_names={unique_column_names}, + title={title}, + description={description}, ) """ @@ -434,7 +470,7 @@ def _format_index(index_statistics): description = properties.get("description") title = properties.get("title") index_code = INDEX_TEMPLATE.format( - dtype=f"{_get_qualified_name(dtype.__class__)}", + dtype=(None if dtype is None else _get_dtype_string_alias(dtype)), checks=( "None" if properties["checks"] is None @@ -479,9 +515,7 @@ def to_script(dataframe_schema, path_or_buf=None): description = properties["description"] title = properties["title"] column_code = COLUMN_TEMPLATE.format( - dtype=( - None if dtype is None else _get_qualified_name(dtype.__class__) - ), + dtype=(None if dtype is None else _get_dtype_string_alias(dtype)), checks=_format_checks(properties["checks"]), nullable=properties["nullable"], unique=properties["unique"], @@ -503,11 +537,18 @@ def to_script(dataframe_schema, path_or_buf=None): script = SCRIPT_TEMPLATE.format( columns=column_str, + checks=statistics["checks"], index=index, + dtype=dataframe_schema.dtype, coerce=dataframe_schema.coerce, strict=dataframe_schema.strict, name=dataframe_schema.name.__repr__(), + ordered=dataframe_schema.ordered, unique=dataframe_schema.unique, + report_duplicates=f'"{dataframe_schema._report_duplicates}"', + unique_column_names=dataframe_schema.unique_column_names, + title=dataframe_schema.title, + description=dataframe_schema.description, ).strip() # add pandas imports to handle datetime and timedelta. diff --git a/tests/io/test_io.py b/tests/io/test_io.py index 5ca633f92..adadc482c 100644 --- a/tests/io/test_io.py +++ b/tests/io/test_io.py @@ -235,10 +235,16 @@ def _create_schema(index="single"): name: null unique: false coerce: false +dtype: null coerce: false strict: true -unique: null +name: null ordered: false +unique: null +report_duplicates: all +unique_column_names: false +title: null +description: null """ @@ -1215,10 +1221,16 @@ def datetime_check(pandas_obj, *, stat): regex: false checks: null index: null +dtype: null coerce: true strict: true -unique: null +name: null ordered: false +unique: null +report_duplicates: all +unique_column_names: false +title: null +description: null """ VALID_FRICTIONLESS_DF = pd.DataFrame(