diff --git a/pandera/io/pandas_io.py b/pandera/io/pandas_io.py index 004a02411..fe17b5fd7 100644 --- a/pandera/io/pandas_io.py +++ b/pandera/io/pandas_io.py @@ -68,15 +68,28 @@ def handle_stat_dtype(stat): return stat - # for unary checks, return a single value instead of a dictionary - if len(check_stats) == 1: - return handle_stat_dtype(list(check_stats.values())[0]) + # Extract check options if they exist + check_options = ( + check_stats.pop("options", {}) if isinstance(check_stats, dict) else {} + ) + + # Handle unary checks + if isinstance(check_stats, dict) and len(check_stats) == 1: + value = handle_stat_dtype(list(check_stats.values())[0]) + if check_options: + return {"value": value, "options": check_options} + return value - # otherwise return a dictionary of keyword args needed to create the Check - serialized_check_stats = {} - for arg, stat in check_stats.items(): - serialized_check_stats[arg] = handle_stat_dtype(stat) - return serialized_check_stats + # Handle dictionary case + if isinstance(check_stats, dict): + serialized_check_stats = {} + for arg, stat in check_stats.items(): + serialized_check_stats[arg] = handle_stat_dtype(stat) + if check_options: + serialized_check_stats["options"] = check_options + return serialized_check_stats + + return handle_stat_dtype(check_stats) def _serialize_dataframe_stats(dataframe_checks): @@ -178,6 +191,8 @@ def serialize_schema(dataframe_schema): def _deserialize_check_stats(check, serialized_check_stats, dtype=None): + """Deserialize check statistics and reconstruct check with options.""" + def handle_stat_dtype(stat): try: if pandas_engine.Engine.dtype(dtypes.DateTime).check(dtype): @@ -189,15 +204,35 @@ def handle_stat_dtype(stat): return stat return stat + # Extract options if they exist + options = {} if isinstance(serialized_check_stats, dict): # handle case where serialized check stats are in the form of a # dictionary mapping Check arg names to values. + options = serialized_check_stats.pop("options", {}) + # Handle special case for unary checks with options + if ( + "value" in serialized_check_stats + and len(serialized_check_stats) == 1 + ): + serialized_check_stats = serialized_check_stats["value"] + + # Create check with original logic + if isinstance(serialized_check_stats, dict): check_stats = {} for arg, stat in serialized_check_stats.items(): check_stats[arg] = handle_stat_dtype(stat) - return check(**check_stats) - # otherwise assume unary check function signature - return check(handle_stat_dtype(serialized_check_stats)) + check_instance = check(**check_stats) + else: + # otherwise assume unary check function signature + check_instance = check(handle_stat_dtype(serialized_check_stats)) + + # Apply options if they exist + if options: + for option_name, option_value in options.items(): + setattr(check_instance, option_name, option_value) + + return check_instance def _deserialize_component_stats(serialized_component_stats): @@ -447,6 +482,7 @@ def to_json(dataframe_schema, target=None, **kwargs): def _format_checks(checks_dict): + """Format checks into string representation including options.""" if checks_dict is None: return "None" @@ -457,11 +493,33 @@ def _format_checks(checks_dict): f"Check {check_name} cannot be serialized. " "This check will be ignored" ) - else: + continue + + # Handle options separately + options = ( + check_kwargs.pop("options", {}) + if isinstance(check_kwargs, dict) + else {} + ) + + # Format main check arguments + if isinstance(check_kwargs, dict): args = ", ".join( f"{k}={v.__repr__()}" for k, v in check_kwargs.items() ) - checks.append(f"Check.{check_name}({args})") + else: + args = check_kwargs.__repr__() + + # Add options to arguments if they exist + if options: + if args: + args += ", " + args += ", ".join( + f"{k}={v.__repr__()}" for k, v in options.items() + ) + + checks.append(f"Check.{check_name}({args})") + return f"[{', '.join(checks)}]" diff --git a/pandera/schema_statistics/pandas.py b/pandera/schema_statistics/pandas.py index 00d51cf54..b2e55f188 100644 --- a/pandera/schema_statistics/pandas.py +++ b/pandera/schema_statistics/pandas.py @@ -68,14 +68,29 @@ def _index_stats(index_level): def parse_check_statistics(check_stats: Union[Dict[str, Any], None]): - """Convert check statistics to a list of Check objects.""" + """Convert check statistics to a list of Check objects, including their options.""" if check_stats is None: return None checks = [] for check_name, stats in check_stats.items(): check = getattr(Check, check_name) try: - checks.append(check(**stats)) + # Extract options if present + if isinstance(stats, dict): + options = ( + stats.pop("options", {}) if "options" in stats else {} + ) + if stats: # If there are remaining stats + check_instance = check(**stats) + else: # Handle case where all stats were in options + check_instance = check() + # Apply options to the check instance + for option_name, option_value in options.items(): + setattr(check_instance, option_name, option_value) + checks.append(check_instance) + else: + # Handle unary check case + checks.append(check(stats)) except TypeError: # if stats cannot be unpacked as key-word args, assume unary check. checks.append(check(stats)) @@ -142,9 +157,10 @@ def get_series_schema_statistics(series_schema): def parse_checks(checks) -> Union[Dict[str, Any], None]: - """Convert Check object to check statistics.""" + """Convert Check object to check statistics including options.""" check_statistics = {} _check_memo = {} + for check in checks: if check not in Check: warnings.warn( @@ -154,28 +170,46 @@ def parse_checks(checks) -> Union[Dict[str, Any], None]: ) continue - check_statistics[check.name] = ( - {} if check.statistics is None else check.statistics - ) + # Get base statistics + base_stats = {} if check.statistics is None else check.statistics + + # Collect check options + check_options = { + "raise_warning": check.raise_warning, + "n_failure_cases": check.n_failure_cases, + "ignore_na": check.ignore_na, + } + + # Filter out None values from options + check_options = { + k: v for k, v in check_options.items() if v is not None + } + + # Combine statistics with options + check_statistics[check.name] = base_stats + if check_options: + check_statistics[check.name]["options"] = check_options + _check_memo[check.name] = check - # raise ValueError on incompatible checks + # Check for incompatible checks if ( "greater_than_or_equal_to" in check_statistics and "less_than_or_equal_to" in check_statistics ): min_value = check_statistics.get( "greater_than_or_equal_to", float("-inf") - )["min_value"] + ).get("min_value", float("-inf")) max_value = check_statistics.get( "less_than_or_equal_to", float("inf") - )["max_value"] + ).get("max_value", float("inf")) if min_value > max_value: raise ValueError( f"checks {_check_memo['greater_than_or_equal_to']} " f"and {_check_memo['less_than_or_equal_to']} are incompatible, reason: " f"min value {min_value} > max value {max_value}" ) + return check_statistics if check_statistics else None diff --git a/tests/core/test_schema_statistics.py b/tests/core/test_schema_statistics.py index 010c67541..78cecfa82 100644 --- a/tests/core/test_schema_statistics.py +++ b/tests/core/test_schema_statistics.py @@ -467,8 +467,14 @@ def test_get_dataframe_schema_statistics(): "int": { "dtype": DEFAULT_INT, "checks": { - "greater_than_or_equal_to": {"min_value": 0}, - "less_than_or_equal_to": {"max_value": 100}, + "greater_than_or_equal_to": { + "min_value": 0, + "options": {"ignore_na": True, "raise_warning": False}, + }, + "less_than_or_equal_to": { + "max_value": 100, + "options": {"ignore_na": True, "raise_warning": False}, + }, }, "nullable": True, "unique": False, @@ -481,8 +487,14 @@ def test_get_dataframe_schema_statistics(): "float": { "dtype": DEFAULT_FLOAT, "checks": { - "greater_than_or_equal_to": {"min_value": 50}, - "less_than_or_equal_to": {"max_value": 100}, + "greater_than_or_equal_to": { + "min_value": 50, + "options": {"ignore_na": True, "raise_warning": False}, + }, + "less_than_or_equal_to": { + "max_value": 100, + "options": {"ignore_na": True, "raise_warning": False}, + }, }, "nullable": False, "unique": False, @@ -494,7 +506,12 @@ def test_get_dataframe_schema_statistics(): }, "str": { "dtype": pandas_engine.Engine.dtype(str), - "checks": {"isin": {"allowed_values": ["foo", "bar", "baz"]}}, + "checks": { + "isin": { + "allowed_values": ["foo", "bar", "baz"], + "options": {"ignore_na": True, "raise_warning": False}, + } + }, "nullable": False, "unique": False, "coerce": False, @@ -507,7 +524,12 @@ def test_get_dataframe_schema_statistics(): "index": [ { "dtype": DEFAULT_INT, - "checks": {"greater_than_or_equal_to": {"min_value": 0}}, + "checks": { + "greater_than_or_equal_to": { + "min_value": 0, + "options": {"ignore_na": True, "raise_warning": False}, + } + }, "nullable": False, "coerce": False, "name": "int_index", @@ -537,8 +559,14 @@ def test_get_series_schema_statistics(): "dtype": pandas_engine.Engine.dtype(int), "nullable": False, "checks": { - "greater_than_or_equal_to": {"min_value": 0}, - "less_than_or_equal_to": {"max_value": 100}, + "greater_than_or_equal_to": { + "min_value": 0, + "options": {"ignore_na": True, "raise_warning": False}, + }, + "less_than_or_equal_to": { + "max_value": 100, + "options": {"ignore_na": True, "raise_warning": False}, + }, }, "name": None, "coerce": False, @@ -566,8 +594,20 @@ def test_get_series_schema_statistics(): "dtype": pandas_engine.Engine.dtype(int), "nullable": False, "checks": { - "greater_than_or_equal_to": {"min_value": 10}, - "less_than_or_equal_to": {"max_value": 20}, + "greater_than_or_equal_to": { + "min_value": 10, + "options": { + "ignore_na": True, + "raise_warning": False, + }, + }, + "less_than_or_equal_to": { + "max_value": 20, + "options": { + "ignore_na": True, + "raise_warning": False, + }, + }, }, "name": "int_index", "coerce": False, @@ -591,7 +631,15 @@ def test_get_index_schema_statistics(index_schema_component, expectation): "checks, expectation", [ *[ - [[check], {check.name: check.statistics}] + [ + [check], + { + check.name: { + **(check.statistics or {}), + "options": {"ignore_na": True, "raise_warning": False}, + } + }, + ] for check in [ pa.Check.greater_than(1), pa.Check.less_than(1), @@ -614,9 +662,18 @@ def test_get_index_schema_statistics(index_schema_component, expectation): pa.Check.isin([10, 20, 30, 40, 50]), ], { - "greater_than_or_equal_to": {"min_value": 10}, - "less_than_or_equal_to": {"max_value": 50}, - "isin": {"allowed_values": [10, 20, 30, 40, 50]}, + "greater_than_or_equal_to": { + "min_value": 10, + "options": {"ignore_na": True, "raise_warning": False}, + }, + "less_than_or_equal_to": { + "max_value": 50, + "options": {"ignore_na": True, "raise_warning": False}, + }, + "isin": { + "allowed_values": [10, 20, 30, 40, 50], + "options": {"ignore_na": True, "raise_warning": False}, + }, }, ], # incompatible checks @@ -650,7 +707,13 @@ def test_parse_checks_and_statistics_roundtrip(checks, expectation): check_statistics = {check.name: check.statistics for check in checks} check_list = schema_statistics.parse_check_statistics(check_statistics) - assert set(check_list) == set(checks) + assert all( + c1.name == c2.name and c1.statistics == c2.statistics + for c1, c2 in zip( + sorted(checks, key=lambda x: x.name), + sorted(check_list, key=lambda x: x.name), + ) + ) # pylint: disable=unused-argument @@ -661,12 +724,20 @@ def test_parse_checks_and_statistics_no_param(extra_registered_checks): """ checks = [pa.Check.no_param_check()] - expectation = {"no_param_check": {}} + expectation = { + "no_param_check": { + "options": {"ignore_na": True, "raise_warning": False} + } + } assert schema_statistics.parse_checks(checks) == expectation check_statistics = {check.name: check.statistics for check in checks} check_list = schema_statistics.parse_check_statistics(check_statistics) - assert set(check_list) == set(checks) - -# pylint: enable=unused-argument + assert all( + c1.name == c2.name and c1.statistics == c2.statistics + for c1, c2 in zip( + sorted(checks, key=lambda x: x.name), + sorted(check_list, key=lambda x: x.name), + ) + ) diff --git a/tests/io/test_io.py b/tests/io/test_io.py index c87302fc6..d87c2de8b 100644 --- a/tests/io/test_io.py +++ b/tests/io/test_io.py @@ -127,13 +127,24 @@ def _create_schema(index="single"): dtype: int64 nullable: false checks: - greater_than: 0 - less_than: 10 + greater_than: + value: 0 + options: + raise_warning: false + ignore_na: true + less_than: + value: 10 + options: + raise_warning: false + ignore_na: true in_range: min_value: 0 max_value: 10 include_min: true include_max: true + options: + raise_warning: false + ignore_na: true unique: false coerce: false required: true @@ -144,13 +155,24 @@ def _create_schema(index="single"): dtype: float64 nullable: false checks: - greater_than: -10 - less_than: 20 + greater_than: + value: -10 + options: + raise_warning: false + ignore_na: true + less_than: + value: 20 + options: + raise_warning: false + ignore_na: true in_range: min_value: -10 max_value: 20 include_min: true include_max: true + options: + raise_warning: false + ignore_na: true unique: false coerce: false required: true @@ -162,13 +184,20 @@ def _create_schema(index="single"): nullable: false checks: isin: - - foo - - bar - - x - - xy + value: + - foo + - bar + - x + - xy + options: + raise_warning: false + ignore_na: true str_length: min_value: 1 max_value: 3 + options: + raise_warning: false + ignore_na: true unique: false coerce: false required: true @@ -179,8 +208,16 @@ def _create_schema(index="single"): dtype: datetime64[ns] nullable: false checks: - greater_than: '2010-01-01 00:00:00' - less_than: '2020-01-01 00:00:00' + greater_than: + value: '2010-01-01 00:00:00' + options: + raise_warning: false + ignore_na: true + less_than: + value: '2020-01-01 00:00:00' + options: + raise_warning: false + ignore_na: true unique: false coerce: false required: true @@ -191,8 +228,16 @@ def _create_schema(index="single"): dtype: timedelta64[ns] nullable: false checks: - greater_than: 1000 - less_than: 10000 + greater_than: + value: 1000 + options: + raise_warning: false + ignore_na: true + less_than: + value: 10000 + options: + raise_warning: false + ignore_na: true unique: false coerce: false required: true @@ -206,6 +251,9 @@ def _create_schema(index="single"): str_length: min_value: 1 max_value: 3 + options: + raise_warning: false + ignore_na: true unique: false coerce: true required: false @@ -217,10 +265,14 @@ def _create_schema(index="single"): nullable: false checks: isin: - - foo - - bar - - x - - xy + value: + - foo + - bar + - x + - xy + options: + raise_warning: false + ignore_na: true unique: false coerce: false required: true @@ -290,13 +342,20 @@ def _create_schema_null_index(): nullable: false checks: isin: - - foo - - bar - - x - - xy + value: + - foo + - bar + - x + - xy + options: + raise_warning: false + ignore_na: true str_length: min_value: 1 max_value: 3 + options: + raise_warning: false + ignore_na: true index: null checks: null coerce: false @@ -388,13 +447,24 @@ def _create_schema_python_types(): dtype: int64 nullable: false checks: - greater_than: 0 - less_than: 10 + greater_than: + value: 0 + options: + raise_warning: false + ignore_na: true + less_than: + value: 10 + options: + raise_warning: false + ignore_na: true in_range: min_value: 0 max_value: 10 include_min: true include_max: true + options: + raise_warning: false + ignore_na: true unique: false coerce: false required: true @@ -405,13 +475,24 @@ def _create_schema_python_types(): dtype: float64 nullable: false checks: - greater_than: -10 - less_than: 20 + greater_than: + value: -10 + options: + raise_warning: false + ignore_na: true + less_than: + value: 20 + options: + raise_warning: false + ignore_na: true in_range: min_value: -10 max_value: 20 include_min: true include_max: true + options: + raise_warning: false + ignore_na: true unique: false coerce: false required: true @@ -423,13 +504,20 @@ def _create_schema_python_types(): nullable: false checks: isin: - - foo - - bar - - x - - xy + value: + - foo + - bar + - x + - xy + options: + raise_warning: false + ignore_na: true str_length: min_value: 1 max_value: 3 + options: + raise_warning: false + ignore_na: true unique: false coerce: false required: true @@ -440,8 +528,16 @@ def _create_schema_python_types(): dtype: datetime64[ns] nullable: false checks: - greater_than: '2010-01-01 00:00:00' - less_than: '2020-01-01 00:00:00' + greater_than: + value: '2010-01-01 00:00:00' + options: + raise_warning: false + ignore_na: true + less_than: + value: '2020-01-01 00:00:00' + options: + raise_warning: false + ignore_na: true unique: false coerce: false required: true @@ -452,8 +548,16 @@ def _create_schema_python_types(): dtype: timedelta64[ns] nullable: false checks: - greater_than: 1000 - less_than: 10000 + greater_than: + value: 1000 + options: + raise_warning: false + ignore_na: true + less_than: + value: 10000 + options: + raise_warning: false + ignore_na: true unique: false coerce: false required: true @@ -467,6 +571,9 @@ def _create_schema_python_types(): str_length: min_value: 1 max_value: 3 + options: + raise_warning: false + ignore_na: true unique: false coerce: true required: false @@ -478,10 +585,14 @@ def _create_schema_python_types(): nullable: false checks: isin: - - foo - - bar - - x - - xy + value: + - foo + - bar + - x + - xy + options: + raise_warning: false + ignore_na: true unique: false coerce: false required: true @@ -1138,6 +1249,9 @@ def datetime_check(pandas_obj, *, stat): ... max_value: 99 include_min: true include_max: true + options: + raise_warning: false + ignore_na: true unique: true coerce: true required: true @@ -1148,7 +1262,11 @@ def datetime_check(pandas_obj, *, stat): ... dtype: {INT_DTYPE} nullable: true checks: - less_than_or_equal_to: 30 + less_than_or_equal_to: + value: 30 + options: + raise_warning: false + ignore_na: true unique: false coerce: true required: true @@ -1162,6 +1280,9 @@ def datetime_check(pandas_obj, *, stat): ... str_length: min_value: 3 max_value: 80 + options: + raise_warning: false + ignore_na: true unique: false coerce: true required: true @@ -1172,7 +1293,11 @@ def datetime_check(pandas_obj, *, stat): ... dtype: {STR_DTYPE} nullable: true checks: - str_matches: ^\\d{{3}}[A-Z]$ + str_matches: + value: ^\\d{{3}}[A-Z]$ + options: + raise_warning: false + ignore_na: true unique: false coerce: true required: true @@ -1186,6 +1311,9 @@ def datetime_check(pandas_obj, *, stat): ... str_length: min_value: 3 max_value: null + options: + raise_warning: false + ignore_na: true unique: false coerce: true required: true @@ -1199,6 +1327,9 @@ def datetime_check(pandas_obj, *, stat): ... str_length: min_value: null max_value: 3 + options: + raise_warning: false + ignore_na: true unique: false coerce: true required: true @@ -1210,9 +1341,13 @@ def datetime_check(pandas_obj, *, stat): ... nullable: false checks: isin: - - 1.0 - - 2.0 - - 3.0 + value: + - 1.0 + - 2.0 + - 3.0 + options: + raise_warning: false + ignore_na: true unique: false coerce: true required: true @@ -1233,7 +1368,11 @@ def datetime_check(pandas_obj, *, stat): ... dtype: {STR_DTYPE} nullable: true checks: - greater_than_or_equal_to: '20201231' + greater_than_or_equal_to: + value: '20201231' + options: + raise_warning: false + ignore_na: true unique: false coerce: true required: true