Skip to content

Commit

Permalink
Allow dtype to be passed to blockwise and reduction functions (#321)
Browse files Browse the repository at this point in the history
* Allow dtype to be passed to blockwise function

* Fix sum and prod dtypes

* Fix apply_gufunc test

* Fix mean intermediate dtypes
  • Loading branch information
tomwhite authored Nov 16, 2023
1 parent 914c542 commit d19caec
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 8 deletions.
32 changes: 26 additions & 6 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def mean(x, /, *, axis=None, keepdims=False):
# outputs.
dtype = x.dtype
intermediate_dtype = [("n", np.int64), ("total", np.float64)]
extra_func_kwargs = dict(dtype=intermediate_dtype)
return reduction(
x,
_mean_func,
Expand All @@ -44,18 +45,21 @@ def mean(x, /, *, axis=None, keepdims=False):
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
extra_func_kwargs=extra_func_kwargs,
)


def _mean_func(a, **kwargs):
n = _numel(a, **kwargs)
total = np.sum(a, **kwargs)
dtype = dict(kwargs.pop("dtype"))
n = _numel(a, dtype=dtype["n"], **kwargs)
total = np.sum(a, dtype=dtype["total"], **kwargs)
return {"n": n, "total": total}


def _mean_combine(a, **kwargs):
n = np.sum(a["n"], **kwargs)
total = np.sum(a["total"], **kwargs)
dtype = dict(kwargs.pop("dtype"))
n = np.sum(a["n"], dtype=dtype["n"], **kwargs)
total = np.sum(a["total"], dtype=dtype["total"], **kwargs)
return {"n": n, "total": total}


Expand Down Expand Up @@ -114,7 +118,15 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False):
dtype = complex128
else:
dtype = x.dtype
return reduction(x, np.prod, axis=axis, dtype=dtype, keepdims=keepdims)
extra_func_kwargs = dict(dtype=dtype)
return reduction(
x,
np.prod,
axis=axis,
dtype=dtype,
keepdims=keepdims,
extra_func_kwargs=extra_func_kwargs,
)


def sum(x, /, *, axis=None, dtype=None, keepdims=False):
Expand All @@ -131,4 +143,12 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False):
dtype = complex128
else:
dtype = x.dtype
return reduction(x, np.sum, axis=axis, dtype=dtype, keepdims=keepdims)
extra_func_kwargs = dict(dtype=dtype)
return reduction(
x,
np.sum,
axis=axis,
dtype=dtype,
keepdims=keepdims,
extra_func_kwargs=extra_func_kwargs,
)
5 changes: 5 additions & 0 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def blockwise(
new_axes=None,
align_arrays=True,
target_store=None,
extra_func_kwargs=None,
**kwargs,
) -> "Array":
arrays = args[::2]
Expand Down Expand Up @@ -277,6 +278,7 @@ def blockwise(
new_axes=new_axes,
in_names=in_names,
out_name=name,
extra_func_kwargs=extra_func_kwargs,
**kwargs,
)
plan = Plan._new(
Expand Down Expand Up @@ -712,6 +714,7 @@ def reduction(
intermediate_dtype=None,
dtype=None,
keepdims=False,
extra_func_kwargs=None,
) -> "Array":
if combine_func is None:
combine_func = func
Expand Down Expand Up @@ -742,6 +745,7 @@ def reduction(
keepdims=True,
dtype=intermediate_dtype,
adjust_chunks=adjust_chunks,
extra_func_kwargs=extra_func_kwargs,
)

# merge/reduce along axis in multiple rounds until there's a single block in each reduction axis
Expand Down Expand Up @@ -783,6 +787,7 @@ def reduction(
keepdims=True,
dtype=intermediate_dtype,
adjust_chunks=adjust_chunks,
extra_func_kwargs=extra_func_kwargs,
)

if aggegrate_func is not None:
Expand Down
7 changes: 6 additions & 1 deletion cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def blockwise(
in_names: Optional[List[str]] = None,
out_name: Optional[str] = None,
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""Apply a function across blocks from multiple source Zarr arrays.
Expand Down Expand Up @@ -127,6 +128,9 @@ def blockwise(
extra_projected_mem : int
Extra memory projected to be needed (in bytes) in addition to the memory used reading
the input arrays and writing the output.
extra_func_kwargs : dict
Extra keyword arguments to pass to function that can't be passed as regular keyword arguments
since they clash with other blockwise arguments (such as dtype).
**kwargs : dict
Extra keyword arguments to pass to function
Expand Down Expand Up @@ -197,7 +201,8 @@ def blockwise(
shape, dtype=dtype, chunks=chunksize, store=target_store
)

func_with_kwargs = partial(func, **kwargs)
func_kwargs = extra_func_kwargs or {}
func_with_kwargs = partial(func, **{**kwargs, **func_kwargs})
read_proxies = {
name: CubedArrayProxy(array, array.chunks) for name, array in array_map.items()
}
Expand Down
3 changes: 2 additions & 1 deletion cubed/tests/test_gufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def spec(tmp_path):
@pytest.mark.parametrize("vectorize", [False, True])
def test_apply_reduction(spec, vectorize):
def stats(x):
return np.mean(x, axis=-1)
# note dtype matches output_dtypes in apply_gufunc below
return np.mean(x, axis=-1, dtype=np.float32)

r = np.random.normal(size=(10, 20, 30))
a = cubed.from_array(r, chunks=(5, 5, 30), spec=spec)
Expand Down

0 comments on commit d19caec

Please sign in to comment.