
Commit

More support for datetime, timedelta (#412)
* More support for datetime, timedelta

Closes #403

* Rework property tests

* Add cftime property tests

* add cftime unit test

* Smarter bool conversion

* typing

* more typing

* cubed bugfix

* xfail one more

* xfail nanprod too
dcherian authored Jan 13, 2025
1 parent 0344a28 commit df81a8d
Showing 8 changed files with 217 additions and 92 deletions.
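
The net effect, mirrored by the new tests at the bottom of this commit: groupby_reduce now accepts datetime64 and timedelta64 arrays directly. A minimal sketch of the new behaviour (the dates and groups are illustrative):

    import numpy as np
    import flox

    array = np.array(["2000-01-01", "2000-01-02", "2000-01-03"], dtype="datetime64[ns]")
    by = np.array([0, 0, 1])

    # No manual round-trip through integers is needed any more.
    mins, groups = flox.groupby_reduce(array, by, func="nanmin")         # per-group datetime64 minima
    last, _ = flox.groupby_reduce(array - array[0], by, func="nanlast")  # timedelta64 works too

The results keep the input dtype because groupby_reduce views the data as int64 for the reduction and casts back afterwards; see the flox/core.py hunks below.
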
4 changes: 3 additions & 1 deletion flox/aggregate_numbagg.py
@@ -30,6 +30,8 @@
"nanmean": {np.int_: np.float64},
"nanvar": {np.int_: np.float64},
"nanstd": {np.int_: np.float64},
"nanfirst": {np.datetime64: np.int64, np.timedelta64: np.int64},
"nanlast": {np.datetime64: np.int64, np.timedelta64: np.int64},
}


@@ -51,7 +53,7 @@ def _numbagg_wrapper(
if cast_to:
for from_, to_ in cast_to.items():
if np.issubdtype(array.dtype, from_):
array = array.astype(to_)
array = array.astype(to_, copy=False)

func_ = getattr(numbagg.grouped, f"group_{func}")

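
The cast_to table above routes datetime64/timedelta64 input through int64 before it reaches numbagg's grouped kernels. A standalone sketch of the wrapper's cast loop (the dict and function names here are illustrative, not the numbagg API):

    import numpy as np

    CAST_TO = {"nanlast": {np.datetime64: np.int64, np.timedelta64: np.int64}}

    def cast_input(func: str, array: np.ndarray) -> np.ndarray:
        # Cast datetime/timedelta input to int64 so the grouped kernel can run;
        # copy=False avoids an extra copy when the dtypes already match.
        for from_, to_ in CAST_TO.get(func, {}).items():
            if np.issubdtype(array.dtype, from_):
                array = array.astype(to_, copy=False)
        return array

    print(cast_input("nanlast", np.array(["2000-01-01"], dtype="datetime64[ns]")).dtype)  # int64
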
57 changes: 54 additions & 3 deletions flox/core.py
@@ -45,6 +45,9 @@
)
from .cache import memoize
from .xrutils import (
_contains_cftime_datetimes,
_to_pytimedelta,
datetime_to_numeric,
is_chunked_array,
is_duck_array,
is_duck_cubed_array,
@@ -171,6 +174,17 @@ def _is_first_last_reduction(func: T_Agg) -> bool:
return func in ["nanfirst", "nanlast", "first", "last"]


def _is_bool_supported_reduction(func: T_Agg) -> bool:
if isinstance(func, Aggregation):
func = func.name
return (
func in ["all", "any"]
# TODO: enable in npg
# or _is_first_last_reduction(func)
# or _is_minmax_reduction(func)
)


def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
if is_duck_dask_array(by):
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
@@ -2422,7 +2436,7 @@ def groupby_reduce(
array.dtype,
)

is_bool_array = np.issubdtype(array.dtype, bool)
is_bool_array = np.issubdtype(array.dtype, bool) and not _is_bool_supported_reduction(func)
array = array.astype(np.int_) if is_bool_array else array

isbins = _atleast_1d(isbin, nby)
@@ -2472,7 +2486,8 @@
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)

if _is_first_last_reduction(func):
is_first_last = _is_first_last_reduction(func)
if is_first_last:
if has_dask and nax != 1:
raise ValueError(
"For dask arrays: first, last, nanfirst, nanlast reductions are "
@@ -2485,6 +2500,22 @@
"along a single axis or when reducing across all dimensions of `by`."
)

is_npdatetime = array.dtype.kind in "Mm"
is_cftime = _contains_cftime_datetimes(array)
requires_numeric = (
(func not in ["count", "any", "all"] and not is_first_last)
        # Flox's count works with non-numeric data, and it's faster than converting.
or (func == "count" and engine != "flox")
or (is_first_last and is_cftime)
)
if requires_numeric:
if is_npdatetime:
datetime_dtype = array.dtype
array = array.view(np.int64)
elif is_cftime:
offset = array.min()
array = datetime_to_numeric(array, offset, datetime_unit="us")

if nax == 1 and by_.ndim > 1 and expected_ is None:
# When we reduce along all axes, we are guaranteed to see all
# groups in the final combine stage, so everything works.
@@ -2670,6 +2701,14 @@

if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)):
result = result.astype(bool)

# Output of count has an int dtype.
if requires_numeric and func != "count":
if is_npdatetime:
result = result.astype(datetime_dtype)
elif is_cftime:
result = _to_pytimedelta(result, unit="us") + offset

return (result, *groups)


@@ -2810,6 +2849,12 @@ def groupby_scan(
(by_,) = bys
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)

if array.dtype.kind in "Mm":
cast_to = array.dtype
array = array.view(np.int64)
else:
cast_to = None

# TODO: move to aggregate_npg.py
if agg.name in ["cumsum", "nancumsum"] and array.dtype.kind in ["i", "u"]:
# https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
@@ -2825,7 +2870,10 @@
(single_axis,) = axis_ # type: ignore[misc]
# avoid some roundoff error when we can.
if by_.shape[-1] == 1 or by_.shape == grp_shape:
return array.astype(agg.dtype)
array = array.astype(agg.dtype)
if cast_to is not None:
array = array.astype(cast_to)
return array

# Made a design choice here to have `preprocess` handle both array and group_idx
# Example: for reversing, we need to reverse the whole array, not just reverse
@@ -2844,6 +2892,9 @@
out = AlignedArrays(array=result, group_idx=by_)
if agg.finalize:
out = agg.finalize(out)

if cast_to is not None:
return out.array.astype(cast_to)
return out.array


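
The pattern running through the groupby_reduce and groupby_scan hunks above is a dtype round-trip: view datetime64/timedelta64 data as int64, reduce, then restore the original dtype; cftime objects instead go through datetime_to_numeric with a microsecond offset and come back via _to_pytimedelta. A standalone sketch of the numpy half, with an illustrative helper name:

    import numpy as np

    def _reduce_datetime(array: np.ndarray, reduce_func) -> np.ndarray:
        # "M" is datetime64, "m" is timedelta64; both are int64 under the hood.
        if array.dtype.kind in "Mm":
            original_dtype = array.dtype
            result = reduce_func(array.view(np.int64))
            return result.astype(original_dtype)
        return reduce_func(array)

    dt = np.arange("2001-01", "2001-02", dtype="datetime64[D]")
    print(_reduce_datetime(dt, np.max))  # 2001-01-31

Per the requires_numeric logic above, first/last on numpy datetimes skip the conversion entirely, while cftime data is always converted because its elements are Python objects.
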
25 changes: 0 additions & 25 deletions flox/xarray.py
@@ -7,7 +7,6 @@
import pandas as pd
import xarray as xr
from packaging.version import Version
from xarray.core.duck_array_ops import _datetime_nanmin

from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func
from .core import (
@@ -18,7 +17,6 @@
)
from .core import rechunk_for_blockwise as rechunk_array_for_blockwise
from .core import rechunk_for_cohorts as rechunk_array_for_cohorts
from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric

if TYPE_CHECKING:
from xarray.core.types import T_DataArray, T_Dataset
@@ -366,22 +364,6 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
if "nan" not in func and func not in ["all", "any", "count"]:
func = f"nan{func}"

# Flox's count works with non-numeric and its faster than converting.
requires_numeric = func not in ["count", "any", "all"] or (
func == "count" and kwargs["engine"] != "flox"
)
if requires_numeric:
is_npdatetime = array.dtype.kind in "Mm"
is_cftime = _contains_cftime_datetimes(array)
if is_npdatetime:
offset = _datetime_nanmin(array)
# xarray always uses np.datetime64[ns] for np.datetime64 data
dtype = "timedelta64[ns]"
array = datetime_to_numeric(array, offset)
elif is_cftime:
offset = array.min()
array = datetime_to_numeric(array, offset, datetime_unit="us")

result, *groups = groupby_reduce(array, *by, func=func, **kwargs)

# Transpose the new quantile dimension to the end. This is ugly.
@@ -395,13 +377,6 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
# output dim order: (*broadcast_dims, *group_dims, quantile_dim)
result = np.moveaxis(result, 0, -1)

# Output of count has an int dtype.
if requires_numeric and func != "count":
if is_npdatetime:
return result.astype(dtype) + offset
elif is_cftime:
return _to_pytimedelta(result, unit="us") + offset

return result

# These data variables do not have any of the core dimension,
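
With the conversion now living in flox.core, the wrapper above simply forwards datetime data to groupby_reduce. A hedged usage sketch (the DataArray contents are illustrative):

    import pandas as pd
    import xarray as xr
    import flox.xarray

    times = xr.DataArray(pd.date_range("2001-01-01", periods=6, freq="D"), dims="time", name="timestamps")
    labels = xr.DataArray([0, 0, 1, 1, 2, 2], dims="time", name="labels")

    # No datetime pre-conversion happens in the wrapper any more.
    result = flox.xarray.xarray_reduce(times, labels, func="nanmax")
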
24 changes: 22 additions & 2 deletions flox/xrutils.py
@@ -213,8 +213,6 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
"""
# TODO: make this function dask-compatible?
# Set offset to minimum if not given
from xarray.core.duck_array_ops import _datetime_nanmin

if offset is None:
if array.dtype.kind in "Mm":
offset = _datetime_nanmin(array)
@@ -345,6 +343,28 @@ def _contains_cftime_datetimes(array) -> bool:
return False


def _datetime_nanmin(array):
"""nanmin() function for datetime64.
Caveats that this function deals with:
- In numpy < 1.18, min() on datetime64 incorrectly ignores NaT
    - numpy nanmin() doesn't work on datetime64 (all versions at the moment of writing)
- dask min() does not work on datetime64 (all versions at the moment of writing)
"""
from .xrdtypes import is_datetime_like

dtype = array.dtype
assert is_datetime_like(dtype)
# (NaT).astype(float) does not produce NaN...
array = np.where(pd.isnull(array), np.nan, array.astype(float))
array = np.nanmin(array)
if isinstance(array, float):
array = np.array(array)
# ...but (NaN).astype("M8") does produce NaT
return array.astype(dtype)


def _select_along_axis(values, idx, axis):
other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
sl = other_ind[:axis] + (idx,) + other_ind[axis:]
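
A quick sanity check of the _datetime_nanmin helper vendored above (assuming it remains importable from flox.xrutils):

    import numpy as np
    from flox.xrutils import _datetime_nanmin

    arr = np.array(["NaT", "2000-01-02", "2000-01-01"], dtype="datetime64[ns]")
    # np.nanmin does not handle datetime64: the helper masks NaT via float,
    # reduces, and casts the scalar result back to the original dtype.
    print(_datetime_nanmin(arr))  # 2000-01-01T00:00:00.000000000
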
89 changes: 49 additions & 40 deletions tests/strategies.py
@@ -13,44 +13,6 @@

Chunks = tuple[tuple[int, ...], ...]


def supported_dtypes() -> st.SearchStrategy[np.dtype]:
return (
npst.integer_dtypes(endianness="=")
| npst.unsigned_integer_dtypes(endianness="=")
| npst.floating_dtypes(endianness="=", sizes=(32, 64))
| npst.complex_number_dtypes(endianness="=")
| npst.datetime64_dtypes(endianness="=")
| npst.timedelta64_dtypes(endianness="=")
| npst.unicode_string_dtypes(endianness="=")
)


# TODO: stop excluding everything but U
array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
by_dtype_st = supported_dtypes()

NON_NUMPY_FUNCS = [
"first",
"last",
"nanfirst",
"nanlast",
"count",
"any",
"all",
] + list(SCIPY_STATS_FUNCS)
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]

func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
numeric_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes
)
all_arrays = npst.arrays(
elements={"allow_subnormal": False},
shape=npst.array_shapes(),
dtype=supported_dtypes(),
)

calendars = st.sampled_from(
[
"standard",
@@ -89,7 +51,7 @@ def units(draw, *, calendar: str) -> str:
def cftime_arrays(
draw: st.DrawFn,
*,
shape: tuple[int, ...],
shape: st.SearchStrategy[tuple[int, ...]] = npst.array_shapes(),
calendars: st.SearchStrategy[str] = calendars,
elements: dict[str, Any] | None = None,
) -> np.ndarray[Any, Any]:
@@ -103,8 +65,55 @@ def cftime_arrays(
return cftime.num2date(values, units=unit, calendar=cal)


numeric_dtypes = (
npst.integer_dtypes(endianness="=")
| npst.unsigned_integer_dtypes(endianness="=")
| npst.floating_dtypes(endianness="=", sizes=(32, 64))
    # TODO: add complex here, not in supported_dtypes
)
numeric_like_dtypes = (
npst.boolean_dtypes()
| numeric_dtypes
| npst.datetime64_dtypes(endianness="=")
| npst.timedelta64_dtypes(endianness="=")
)
supported_dtypes = (
numeric_like_dtypes
| npst.unicode_string_dtypes(endianness="=")
| npst.complex_number_dtypes(endianness="=")
)
by_dtype_st = supported_dtypes

NON_NUMPY_FUNCS = [
"first",
"last",
"nanfirst",
"nanlast",
"count",
"any",
"all",
] + list(SCIPY_STATS_FUNCS)
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]

func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
numeric_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=numeric_dtypes
)
numeric_like_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=numeric_like_dtypes
)
all_arrays = (
npst.arrays(
elements={"allow_subnormal": False},
shape=npst.array_shapes(),
dtype=numeric_like_dtypes,
)
| cftime_arrays()
)


def by_arrays(
shape: tuple[int, ...], *, elements: dict[str, Any] | None = None
shape: st.SearchStrategy[tuple[int, ...]], *, elements: dict[str, Any] | None = None
) -> st.SearchStrategy[np.ndarray[Any, Any]]:
if elements is None:
elements = {}
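
A hedged sketch of how the reworked strategies can be explored interactively (the tests.strategies import path is an assumption; flox's property tests consume these strategies elsewhere in the suite):

    import hypothesis.strategies as st

    from tests.strategies import all_arrays, by_arrays, cftime_arrays

    # .example() is meant for interactive exploration of what a strategy generates, not for tests.
    print(all_arrays.example())                           # numeric, bool, datetime64, timedelta64, or cftime data
    print(cftime_arrays(shape=st.just((3,))).example())   # three cftime datetimes in a drawn calendar
    print(by_arrays(shape=st.just((3,))).example())       # grouping labels for a length-3 axis
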
29 changes: 29 additions & 0 deletions tests/test_core.py
@@ -2007,3 +2007,32 @@ def test_blockwise_avoid_rechunk():
actual, groups = groupby_reduce(array, by, func="first")
assert_equal(groups, ["", "0", "1"])
assert_equal(actual, np.array([0, 0, 0], dtype=np.int64))


def test_datetime_minmax(engine):
# GH403
array = np.array([np.datetime64("2000-01-01"), np.datetime64("2000-01-02"), np.datetime64("2000-01-03")])
by = np.array([0, 0, 1])
actual, _ = flox.groupby_reduce(array, by, func="nanmin", engine=engine)
expected = array[[0, 2]]
assert_equal(expected, actual)

expected = array[[1, 2]]
actual, _ = flox.groupby_reduce(array, by, func="nanmax", engine=engine)
assert_equal(expected, actual)


@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
def test_datetime_timedelta_first_last(engine, func):
import flox

idx = 0 if "first" in func else -1

dt = pd.date_range("2001-01-01", freq="d", periods=5).values
by = np.ones(dt.shape, dtype=int)
actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine)
assert_equal(actual, dt[[idx]])

dt = dt - dt[0]
actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine)
assert_equal(actual, dt[[idx]])
