Skip to content

Commit

Permalink
Rework property tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jan 8, 2025
1 parent f1bd894 commit 5d75139
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 23 deletions.
37 changes: 20 additions & 17 deletions tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,22 @@

Chunks = tuple[tuple[int, ...], ...]


def supported_dtypes() -> st.SearchStrategy[np.dtype]:
return (
npst.integer_dtypes(endianness="=")
| npst.unsigned_integer_dtypes(endianness="=")
| npst.floating_dtypes(endianness="=", sizes=(32, 64))
| npst.complex_number_dtypes(endianness="=")
| npst.datetime64_dtypes(endianness="=")
| npst.timedelta64_dtypes(endianness="=")
| npst.unicode_string_dtypes(endianness="=")
)


numeric_dtypes = (
npst.integer_dtypes(endianness="=")
| npst.unsigned_integer_dtypes(endianness="=")
| npst.floating_dtypes(endianness="=", sizes=(32, 64))
# TODO: add complex here not in supported_dtypes
)
# TODO: stop excluding everything but U
array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cU")
by_dtype_st = supported_dtypes()
numeric_like_dtypes = (
numeric_dtypes | npst.datetime64_dtypes(endianness="=") | npst.timedelta64_dtypes(endianness="=")
)
supported_dtypes = (
numeric_like_dtypes
| npst.unicode_string_dtypes(endianness="=")
| npst.complex_number_dtypes(endianness="=")
)
by_dtype_st = supported_dtypes

NON_NUMPY_FUNCS = [
"first",
Expand All @@ -43,12 +43,15 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:

func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
numeric_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=numeric_dtypes
)
numeric_like_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=numeric_like_dtypes
)
all_arrays = npst.arrays(
elements={"allow_subnormal": False},
shape=npst.array_shapes(),
dtype=supported_dtypes(),
dtype=supported_dtypes,
)

calendars = st.sampled_from(
Expand Down
13 changes: 7 additions & 6 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from flox.xrutils import isnull, notnull

from . import assert_equal
from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays
from .strategies import by_arrays, chunked_arrays, func_st, numeric_dtypes, numeric_like_arrays
from .strategies import chunks as chunks_strategy

dask.config.set(scheduler="sync")
Expand Down Expand Up @@ -66,7 +66,7 @@ def not_overflowing_array(array: np.ndarray[Any, Any]) -> bool:

@given(
data=st.data(),
array=st.one_of(numeric_arrays, chunked_arrays(arrays=numeric_arrays)),
array=st.one_of(numeric_like_arrays, chunked_arrays(arrays=numeric_like_arrays)),
func=func_st,
)
def test_groupby_reduce(data, array, func: str) -> None:
Expand Down Expand Up @@ -136,7 +136,7 @@ def test_groupby_reduce(data, array, func: str) -> None:

@given(
data=st.data(),
array=chunked_arrays(arrays=numeric_arrays),
array=chunked_arrays(arrays=numeric_like_arrays),
func=func_st,
)
def test_groupby_reduce_numpy_vs_dask(data, array, func: str) -> None:
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_groupby_reduce_numpy_vs_dask(data, array, func: str) -> None:
@settings(report_multiple_bugs=False)
@given(
data=st.data(),
array=chunked_arrays(arrays=numeric_arrays),
array=chunked_arrays(arrays=numeric_like_arrays),
func=st.sampled_from(tuple(NUMPY_SCAN_FUNCS)),
)
def test_scans(data, array: dask.array.Array, func: str) -> None:
Expand Down Expand Up @@ -294,11 +294,12 @@ def test_first_last_useless(data, func):
assert_equal(actual, expected)


@settings(report_multiple_bugs=False)
@given(
func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]),
engine=st.sampled_from(["numpy", "flox"]),
array_dtype=st.none() | array_dtypes,
dtype=st.none() | array_dtypes,
array_dtype=st.none() | numeric_dtypes,
dtype=st.none() | numeric_dtypes,
)
def test_agg_dtype_specified(func, array_dtype, dtype, engine):
# regression test for GH388
Expand Down

0 comments on commit 5d75139

Please sign in to comment.