Skip to content

Commit

Permalink
fix: remove Category inheritance from ArrowDictionary (#1848)
Browse files Browse the repository at this point in the history
* fix: remove Category inheritance from ArrowDictionary

Signed-off-by: Daren Liang <[email protected]>

* Fix test cases

Signed-off-by: Daren Liang <[email protected]>

---------

Signed-off-by: Daren Liang <[email protected]>
  • Loading branch information
darenliang authored Dec 3, 2024
1 parent 9667234 commit 95110a6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pandera/engines/pandas_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1661,7 +1661,7 @@ def from_parametrized_dtype(cls, pyarrow_dtype: pyarrow.TimestampType):
equivalents=[pyarrow.dictionary, pyarrow.DictionaryType]
)
@immutable(init=True)
-class ArrowDictionary(ArrowDataType, dtypes.Category):
+class ArrowDictionary(ArrowDataType):
"""Semantic representation of a :class:`pyarrow.dictionary`."""

type: Optional[pd.ArrowDtype] = dataclasses.field(
Expand Down
2 changes: 1 addition & 1 deletion pandera/engines/pyarrow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def from_parametrized_dtype(cls, pyarrow_dtype: pyarrow.TimestampType):
equivalents=[pyarrow.dictionary, pyarrow.DictionaryType]
)
@immutable(init=True)
-class ArrowDictionary(ArrowDataType, dtypes.Category):
+class ArrowDictionary(ArrowDataType):
"""Semantic representation of a :class:`pyarrow.dictionary`."""

type: Optional[pd.ArrowDtype] = dataclasses.field(default=None, init=False)
Expand Down
61 changes: 26 additions & 35 deletions tests/core/test_pandas_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,10 @@ def test_pandas_data_type_check(data_type_cls):
return

check_result = data_type.check(
-pandas_engine.Engine.dtype(data_container.dtype),
-data_container,
+pandas_engine.Engine.dtype(data_container.dtype), data_container
)
assert isinstance(check_result, bool) or isinstance(
-check_result.all(),
-(bool, np.bool_),
+check_result.all(), (bool, np.bool_)
)


Expand Down Expand Up @@ -210,7 +208,7 @@ def test_pandas_datetimetz_dtype(timezone_aware, data, timezone):
data=pd_st.series(
dtype="datetime64[ns]",
index=pd_st.range_indexes(min_size=5, max_size=10),
-),
+)
)
def test_pandas_date_coerce_dtype(to_df, data):
"""Test that pandas Date dtype coerces to datetime.date object."""
Expand Down Expand Up @@ -241,22 +239,11 @@ def test_pandas_date_coerce_dtype(to_df, data):


pandas_arrow_dtype_cases = (
-(
-pd.Series([["a", "b", "c"]]),
-pyarrow.list_(pyarrow.string()),
-),
-(
-pd.Series([["a", "b"]]),
-pyarrow.list_(pyarrow.string(), 2),
-),
+(pd.Series([["a", "b", "c"]]), pyarrow.list_(pyarrow.string())),
+(pd.Series([["a", "b"]]), pyarrow.list_(pyarrow.string(), 2)),
(
pd.Series([{"foo": 1, "bar": "a"}]),
-pyarrow.struct(
-[
-("foo", pyarrow.int64()),
-("bar", pyarrow.string()),
-]
-),
+pyarrow.struct([("foo", pyarrow.int64()), ("bar", pyarrow.string())]),
),
(pd.Series([None, pd.NA, np.nan]), pyarrow.null),
(pd.Series([None, date(1970, 1, 1)]), pyarrow.date32),
Expand All @@ -277,6 +264,10 @@ def test_pandas_date_coerce_dtype(to_df, data):
(pd.Series(["foo", "bar", "baz", None]), pyarrow.binary(3)),
(pd.Series(["foo", "barbaz", None]), pyarrow.large_binary()),
(pd.Series(["1", "1.0", "foo", "bar", None]), pyarrow.large_string()),
+(
+pd.Series(["a", "b", "c"]),
+pyarrow.dictionary(pyarrow.int64(), pyarrow.string()),
+),
)


Expand All @@ -289,26 +280,16 @@ def test_pandas_arrow_dtype(data, dtype):
pytest.skip("Support of pandas 2.0.0+ with pyarrow only")
dtype = pandas_engine.Engine.dtype(dtype)

-dtype.coerce(data)
+coerced_data = dtype.coerce(data)
+assert coerced_data.dtype == dtype.type


pandas_arrow_dtype_error_cases = (
-(
-pd.Series([["a", "b", "c"]]),
-pyarrow.list_(pyarrow.int64()),
-),
-(
-pd.Series([["a", "b"]]),
-pyarrow.list_(pyarrow.string(), 3),
-),
+(pd.Series([["a", "b", "c"]]), pyarrow.list_(pyarrow.int64())),
+(pd.Series([["a", "b"]]), pyarrow.list_(pyarrow.string(), 3)),
(
pd.Series([{"foo": 1, "bar": "a"}]),
-pyarrow.struct(
-[
-("foo", pyarrow.string()),
-("bar", pyarrow.int64()),
-]
-),
+pyarrow.struct([("foo", pyarrow.string()), ("bar", pyarrow.int64())]),
),
(pd.Series(["a", "1"]), pyarrow.null),
(pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date32),
Expand All @@ -329,6 +310,14 @@ def test_pandas_arrow_dtype(data, dtype):
(pd.Series(["foo", "bar", "baz", None]), pyarrow.binary(2)),
(pd.Series([1, "foo", "barbaz", None]), pyarrow.large_binary()),
(pd.Series([1, 1.0, "foo", "bar", None]), pyarrow.large_string()),
+(
+pd.Series([1.0, 2.0, 3.0]),
+pyarrow.dictionary(pyarrow.int64(), pyarrow.float64()),
+),
+(
+pd.Series(["a", "b", "c"]),
+pyarrow.dictionary(pyarrow.int64(), pyarrow.int64()),
+),
)


Expand All @@ -347,6 +336,8 @@ def test_pandas_arrow_dtype_error(data, dtype):
pyarrow.ArrowTypeError,
NotImplementedError,
ValueError,
+AssertionError,
)
):
-dtype.coerce(data)
+coerced_data = dtype.coerce(data)
+assert coerced_data.dtype == dtype.type

0 comments on commit 95110a6

Please sign in to comment.