diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index 9a448ac9..b4ff7694 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -35,6 +35,7 @@ def mean(x, /, *, axis=None, keepdims=False): # outputs. dtype = x.dtype intermediate_dtype = [("n", np.int64), ("total", np.float64)] + extra_func_kwargs = dict(dtype=intermediate_dtype) return reduction( x, _mean_func, @@ -44,18 +45,21 @@ def mean(x, /, *, axis=None, keepdims=False): intermediate_dtype=intermediate_dtype, dtype=dtype, keepdims=keepdims, + extra_func_kwargs=extra_func_kwargs, ) def _mean_func(a, **kwargs): - n = _numel(a, **kwargs) - total = np.sum(a, **kwargs) + dtype = dict(kwargs.pop("dtype")) + n = _numel(a, dtype=dtype["n"], **kwargs) + total = np.sum(a, dtype=dtype["total"], **kwargs) return {"n": n, "total": total} def _mean_combine(a, **kwargs): - n = np.sum(a["n"], **kwargs) - total = np.sum(a["total"], **kwargs) + dtype = dict(kwargs.pop("dtype")) + n = np.sum(a["n"], dtype=dtype["n"], **kwargs) + total = np.sum(a["total"], dtype=dtype["total"], **kwargs) return {"n": n, "total": total} @@ -114,7 +118,15 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False): dtype = complex128 else: dtype = x.dtype - return reduction(x, np.prod, axis=axis, dtype=dtype, keepdims=keepdims) + extra_func_kwargs = dict(dtype=dtype) + return reduction( + x, + np.prod, + axis=axis, + dtype=dtype, + keepdims=keepdims, + extra_func_kwargs=extra_func_kwargs, + ) def sum(x, /, *, axis=None, dtype=None, keepdims=False): @@ -131,4 +143,12 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False): dtype = complex128 else: dtype = x.dtype - return reduction(x, np.sum, axis=axis, dtype=dtype, keepdims=keepdims) + extra_func_kwargs = dict(dtype=dtype) + return reduction( + x, + np.sum, + axis=axis, + dtype=dtype, + keepdims=keepdims, + extra_func_kwargs=extra_func_kwargs, + ) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 97bc2643..1a4b1b78 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -186,6 +186,7 @@ def blockwise( new_axes=None, align_arrays=True, target_store=None, + extra_func_kwargs=None, **kwargs, ) -> "Array": arrays = args[::2] @@ -277,6 +278,7 @@ def blockwise( new_axes=new_axes, in_names=in_names, out_name=name, + extra_func_kwargs=extra_func_kwargs, **kwargs, ) plan = Plan._new( @@ -712,6 +714,7 @@ def reduction( intermediate_dtype=None, dtype=None, keepdims=False, + extra_func_kwargs=None, ) -> "Array": if combine_func is None: combine_func = func @@ -742,6 +745,7 @@ def reduction( keepdims=True, dtype=intermediate_dtype, adjust_chunks=adjust_chunks, + extra_func_kwargs=extra_func_kwargs, ) # merge/reduce along axis in multiple rounds until there's a single block in each reduction axis @@ -783,6 +787,7 @@ def reduction( keepdims=True, dtype=intermediate_dtype, adjust_chunks=adjust_chunks, + extra_func_kwargs=extra_func_kwargs, ) if aggegrate_func is not None: diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index 0d0bbebd..eba1b3a0 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -98,6 +98,7 @@ def blockwise( in_names: Optional[List[str]] = None, out_name: Optional[str] = None, extra_projected_mem: int = 0, + extra_func_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ): """Apply a function across blocks from multiple source Zarr arrays. @@ -127,6 +128,9 @@ def blockwise( extra_projected_mem : int Extra memory projected to be needed (in bytes) in addition to the memory used reading the input arrays and writing the output. + extra_func_kwargs : dict + Extra keyword arguments to pass to function that can't be passed as regular keyword arguments + since they clash with other blockwise arguments (such as dtype). **kwargs : dict Extra keyword arguments to pass to function @@ -197,7 +201,8 @@ def blockwise( shape, dtype=dtype, chunks=chunksize, store=target_store ) - func_with_kwargs = partial(func, **kwargs) + func_kwargs = extra_func_kwargs or {} + func_with_kwargs = partial(func, **{**kwargs, **func_kwargs}) read_proxies = { name: CubedArrayProxy(array, array.chunks) for name, array in array_map.items() } diff --git a/cubed/tests/test_gufunc.py b/cubed/tests/test_gufunc.py index 07e59e33..7944d609 100644 --- a/cubed/tests/test_gufunc.py +++ b/cubed/tests/test_gufunc.py @@ -15,7 +15,8 @@ def spec(tmp_path): @pytest.mark.parametrize("vectorize", [False, True]) def test_apply_reduction(spec, vectorize): def stats(x): - return np.mean(x, axis=-1) + # note dtype matches output_dtypes in apply_gufunc below + return np.mean(x, axis=-1, dtype=np.float32) r = np.random.normal(size=(10, 20, 30)) a = cubed.from_array(r, chunks=(5, 5, 30), spec=spec)