diff --git a/pandera/decorators.py b/pandera/decorators.py index f4afee84..a36bb597 100644 --- a/pandera/decorators.py +++ b/pandera/decorators.py @@ -258,41 +258,29 @@ def _wrapper(*args, **kwargs): pos_args[obj_getter], *validate_args ) args = list(pos_args.values()) - elif obj_getter is None and kwargs: - # get the first key in the same order specified in the - # function argument. - args_names = _get_fn_argnames(wrapped) - - try: - kwargs[args_names[0]] = schema.validate( - kwargs[args_names[0]], *validate_args - ) - except errors.SchemaError as e: - _handle_schema_error( - "check_input", - wrapped, - schema, - kwargs[args_names[0]], - e, - ) - elif obj_getter is None and args: + elif obj_getter is None: try: - _fn = ( - wrapped - if not hasattr(wrapped, "__wrapped__") - else wrapped.__wrapped__ - ) + _fn = _unwrap_fn(wrapped) + obj_arg_name, *_ = _get_fn_argnames(wrapped) arg_spec_args = inspect.getfullargspec(_fn).args - if arg_spec_args[0] in ("self", "cls"): - arg_idx = 0 if len(args) == 1 else 1 + + arg_idx = arg_spec_args.index(obj_arg_name) + + if obj_arg_name in kwargs: + obj = kwargs[obj_arg_name] + kwargs[obj_arg_name] = schema.validate( + obj, *validate_args + ) + elif obj_arg_name in pos_args: + obj = args[arg_idx] + args[arg_idx] = schema.validate(obj, *validate_args) else: - arg_idx = 0 - args[arg_idx] = schema.validate( - args[arg_idx], *validate_args - ) + raise ValueError( + f"argument {obj_arg_name} not found in args or kwargs" + ) except errors.SchemaError as e: _handle_schema_error( - "check_input", wrapped, schema, args[0], e + "check_input", wrapped, schema, obj, e ) else: raise TypeError( diff --git a/tests/core/test_decorators.py b/tests/core/test_decorators.py index e63eccb6..c3b7360b 100644 --- a/tests/core/test_decorators.py +++ b/tests/core/test_decorators.py @@ -348,6 +348,25 @@ def _assert_expectation(result_df): ) +class DfModel(DataFrameModel): + col: int + + +# pylint: disable=unused-argument +@check_input(DfModel.to_schema()) +def fn_with_check_input(data: DataFrame[DfModel], *, kwarg: bool = False): + return data + + +def test_check_input_on_fn_with_kwarg(): + """ + That that a check_input correctly validates a function where the first arg + is the dataframe and the function has other kwargs. + """ + df = pd.DataFrame({"col": [1]}) + fn_with_check_input(df, kwarg=True) + + def test_check_io() -> None: # pylint: disable=too-many-locals """Test that check_io correctly validates/invalidates data.""" @@ -777,13 +796,13 @@ def test_check_types_with_literal_type(arg_examples): """Test that using typing module types works with check_types""" for example in arg_examples: - arg_type = Literal[example] + arg_type = Literal[example] # type: ignore @check_types def transform_with_literal( df: DataFrame[InSchema], # pylint: disable=unused-argument,cell-var-from-loop - arg: arg_type, + arg: arg_type, # type: ignore ) -> DataFrame[OutSchema]: return df.assign(b=100) # type: ignore