From 3000febba77f7bb340f248c95c7eca53a736c812 Mon Sep 17 00:00:00 2001
From: Tom White
Date: Sun, 28 Apr 2024 09:10:20 +0100
Subject: [PATCH] Add a `groupby_blockwise` function for use in Flox (#448)

---
 cubed/core/groupby.py       | 144 +++++++++++++++++++++++++++++++++++-
 cubed/tests/test_groupby.py |  93 ++++++++++++++++++++++-
 2 files changed, 234 insertions(+), 3 deletions(-)

diff --git a/cubed/core/groupby.py b/cubed/core/groupby.py
index 726906d2..fece0853 100644
--- a/cubed/core/groupby.py
+++ b/cubed/core/groupby.py
@@ -2,7 +2,9 @@
 
 from cubed.array_api.manipulation_functions import broadcast_to, expand_dims
 from cubed.backend_array_api import namespace as nxp
-from cubed.core.ops import map_blocks, reduction_new
+from cubed.core.ops import map_blocks, map_direct, reduction_new
+from cubed.utils import array_memory, get_item
+from cubed.vendor.dask.array.core import normalize_chunks
 
 if TYPE_CHECKING:
     from cubed.array_api.array_object import Array
@@ -22,7 +24,7 @@ def groupby_reduction(
     num_groups=None,
     extra_func_kwargs=None,
 ) -> "Array":
-    """A reduction that performs groupby aggregations.
+    """A reduction operation that performs groupby aggregations.
 
     Parameters
     ----------
@@ -116,3 +118,141 @@ def wrapper(a, by, **kwargs):
         combine_sizes={axis: num_groups},  # group axis doesn't have size 1
         extra_func_kwargs=dict(dtype=intermediate_dtype, dummy_axis=dummy_axis),
     )
+
+
+def groupby_blockwise(
+    x: "Array",
+    by,
+    func,
+    axis=None,
+    dtype=None,
+    num_groups=None,
+    extra_func_kwargs=None,
+):
+    """A blockwise operation that performs groupby aggregations.
+
+    Parameters
+    ----------
+    x: Array
+        Array being grouped along one axis.
+    by: nxp.array
+        Array of non-negative integers to be used as labels with which to group
+        the values in ``x`` along the reduction axis. Must be a 1D array.
+    func: callable
+        Function to apply to each chunk of data. The output of the
+        function is a chunk with size corresponding to the number of groups in the
+        input chunk along the reduction axis.
+    axis: int or sequence of ints, optional
+        Axis to aggregate along. Only supports a single axis.
+    dtype: dtype
+        Data type of output.
+    num_groups: int
+        The number of groups in the grouping array ``by``.
+    extra_func_kwargs: dict, optional
+        Extra keyword arguments to pass to ``func``.
+    """
+
+    if by.ndim != 1:
+        raise ValueError(f"Array `by` must be 1D, but has {by.ndim} dimensions.")
+
+    if isinstance(axis, tuple):
+        if len(axis) != 1:
+            raise ValueError(
+                f"Only a single axis is supported for groupby_reduction: {axis}"
+            )
+        axis = axis[0]
+
+    newchunks, groups_per_chunk = _get_chunks_for_groups(
+        x.numblocks[axis],
+        by,
+        num_groups=num_groups,
+    )
+
+    # calculate the chunking used to read the input array 'x'
+    read_chunks = tuple(newchunks if i == axis else c for i, c in enumerate(x.chunks))
+
+    # 'by' is not a cubed array, but we still read it in chunks
+    by_read_chunks = (newchunks,)
+
+    # find shape and chunks for the output
+    shape = tuple(num_groups if i == axis else s for i, s in enumerate(x.shape))
+    chunks = tuple(
+        groups_per_chunk if i == axis else c for i, c in enumerate(x.chunksize)
+    )
+    target_chunks = normalize_chunks(chunks, shape, dtype=dtype)
+
+    # memory allocated by reading one chunk from input array
+    # note that although read_chunks will overlap multiple input chunks, zarr will
+    # read the chunks in series, reusing the buffer
+    extra_projected_mem = x.chunkmem
+
+    # memory allocated for largest of (variable sized) read_chunks
+    read_chunksize = tuple(max(c) for c in read_chunks)
+    extra_projected_mem += array_memory(x.dtype, read_chunksize)
+
+    return map_direct(
+        _process_blockwise_chunk,
+        x,
+        shape=shape,
+        dtype=dtype,
+        chunks=target_chunks,
+        extra_projected_mem=extra_projected_mem,
+        axis=axis,
+        by=by,
+        blockwise_func=func,
+        read_chunks=read_chunks,
+        by_read_chunks=by_read_chunks,
+        target_chunks=target_chunks,
+        groups_per_chunk=groups_per_chunk,
+        extra_func_kwargs=extra_func_kwargs,
+    )
+
+
+def _process_blockwise_chunk(
+    x,
+    *arrays,
+    axis=None,
+    by=None,
+    blockwise_func=None,
+    read_chunks=None,
+    by_read_chunks=None,
+    target_chunks=None,
+    groups_per_chunk=None,
+    block_id=None,
+    **kwargs,
+):
+    array = arrays[0].zarray  # underlying Zarr array (or virtual array)
+    idx = block_id
+    bi = idx[axis]
+
+    result = array[get_item(read_chunks, idx)]
+    by = by[get_item(by_read_chunks, (bi,))]
+
+    start_group = bi * groups_per_chunk
+
+    return blockwise_func(
+        result,
+        by,
+        axis=axis,
+        start_group=start_group,
+        num_groups=target_chunks[axis][bi],
+        **kwargs,
+    )
+
+
+def _get_chunks_for_groups(num_chunks, labels, num_groups):
+    """Find new chunking so that there are an equal number of group labels per chunk."""
+
+    # find the start indexes of each group
+    start_indexes = nxp.searchsorted(labels, nxp.arange(num_groups))
+
+    # find the number of groups per chunk
+    groups_per_chunk = max(num_groups // num_chunks, 1)
+
+    # each chunk has groups_per_chunk groups in it (except possibly last one)
+    chunk_boundaries = start_indexes[::groups_per_chunk]
+
+    # successive differences give the new chunk sizes (include end index for last chunk)
+    newchunks = nxp.diff(chunk_boundaries, append=len(labels))
+
+    return tuple(newchunks), groups_per_chunk
diff --git a/cubed/tests/test_groupby.py b/cubed/tests/test_groupby.py
index 16e2b1a4..e14012db 100644
--- a/cubed/tests/test_groupby.py
+++ b/cubed/tests/test_groupby.py
@@ -1,10 +1,15 @@
 import numpy as np
 import numpy_groupies as npg
+import pytest
 from numpy.testing import assert_array_equal
 
 import cubed.array_api as xp
 from cubed.backend_array_api import namespace as nxp
-from cubed.core.groupby import groupby_reduction
+from cubed.core.groupby import (
+    _get_chunks_for_groups,
+    groupby_blockwise,
+    groupby_reduction,
+)
 
 
 def test_groupby_reduction_axis0():
@@ -59,3 +64,89 @@ def _mean_groupby_combine(a, axis, dummy_axis, dtype, keepdims):
 
 def _mean_groupby_aggregate(a):
     return nxp.divide(a["total"], a["n"])
+
+
+@pytest.mark.parametrize(
+    "num_chunks, expected_newchunks, expected_groups_per_chunk",
+    [
+        [10, (3, 2, 2, 0, 3), 1],
+        [5, (3, 2, 2, 0, 3), 1],
+        [4, (3, 2, 2, 0, 3), 1],
+        [3, (3, 2, 2, 0, 3), 1],
+        [2, (5, 2, 3), 2],
+        [2, (5, 2, 3), 2],
+        [2, (5, 2, 3), 2],
+        [2, (5, 2, 3), 2],
+        [2, (5, 2, 3), 2],
+        [1, (10), 5],
+    ],
+)
+def test_get_chunks_for_groups(
+    num_chunks, expected_newchunks, expected_groups_per_chunk
+):
+    # group 3 has no data
+    labels = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])
+    newchunks, groups_per_chunk = _get_chunks_for_groups(
+        num_chunks, labels, num_groups=5
+    )
+    assert_array_equal(newchunks, expected_newchunks)
+    assert groups_per_chunk == expected_groups_per_chunk
+
+
+def test_groupby_blockwise_axis0():
+    a = xp.ones((10, 3), dtype=nxp.int32, chunks=(6, 2))
+    b = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])
+    extra_func_kwargs = dict(dtype=nxp.int32)
+    c = groupby_blockwise(
+        a,
+        b,
+        func=_sum_reduction_func,
+        axis=0,
+        dtype=nxp.int64,
+        num_groups=6,
+        extra_func_kwargs=extra_func_kwargs,
+    )
+    assert_array_equal(
+        c.compute(),
+        nxp.asarray(
+            [
+                [3, 3, 3],
+                [2, 2, 2],
+                [2, 2, 2],
+                [0, 0, 0],  # group 3 has no data
+                [3, 3, 3],
+                [0, 0, 0],  # final group since we specified num_groups=6
+            ]
+        ),
+    )
+
+
+def test_groupby_blockwise_axis1():
+    a = xp.ones((3, 10), dtype=nxp.int32, chunks=(6, 2))
+    b = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])
+    extra_func_kwargs = dict(dtype=nxp.int32)
+    c = groupby_blockwise(
+        a,
+        b,
+        func=_sum_reduction_func,
+        axis=1,
+        dtype=nxp.int64,
+        num_groups=6,
+        extra_func_kwargs=extra_func_kwargs,
+    )
+    assert_array_equal(
+        c.compute(),
+        nxp.asarray(
+            [
+                [3, 2, 2, 0, 3, 0],
+                [3, 2, 2, 0, 3, 0],
+                [3, 2, 2, 0, 3, 0],
+            ]
+        ),
+    )
+
+
+def _sum_reduction_func(arr, by, axis, start_group, num_groups, dtype):
+    # change 'by' so it starts from 0 for each chunk
+    by = by - start_group
+    return npg.aggregate(by, arr, func="sum", dtype=dtype, axis=axis, size=num_groups)
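
Usage sketch (not part of the patch): the snippet below shows how the new
groupby_blockwise function can be called, assembled directly from
test_groupby_blockwise_axis0 above. The helper sum_by_group is simply a renamed
copy of the test's _sum_reduction_func; a real caller would supply its own
per-chunk aggregation with the same (arr, by, axis, start_group, num_groups)
signature.

    import numpy_groupies as npg

    import cubed.array_api as xp
    from cubed.backend_array_api import namespace as nxp
    from cubed.core.groupby import groupby_blockwise


    def sum_by_group(arr, by, axis, start_group, num_groups, dtype):
        # shift 'by' so labels start from 0 within this chunk, then aggregate
        by = by - start_group
        return npg.aggregate(by, arr, func="sum", dtype=dtype, axis=axis, size=num_groups)


    a = xp.ones((10, 3), dtype=nxp.int32, chunks=(6, 2))  # array grouped along axis 0
    by = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])  # one group label per row of 'a'

    grouped = groupby_blockwise(
        a,
        by,
        func=sum_by_group,
        axis=0,
        dtype=nxp.int64,
        num_groups=6,
        extra_func_kwargs=dict(dtype=nxp.int32),
    )
    result = grouped.compute()  # shape (6, 3): one row of sums per group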