diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 39121fb2..17e1e7d4 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -583,45 +583,25 @@ def merged_chunk_len_for_indexer(ia, c): target_chunks = normalize_chunks(chunks, shape, dtype=dtype) - if _is_basic_selection(idx): - # use map_selection (which uses general_blockwise) to allow more opportunities for optimization than map_direct + # use map_selection (which uses general_blockwise) to allow more opportunities for optimization than map_direct - def selection_function(out_key): - out_coords = out_key[1:] - return _target_chunk_selection(target_chunks, out_coords, selection) + def selection_function(out_key): + out_coords = out_key[1:] + return _target_chunk_selection(target_chunks, out_coords, selection) - max_num_input_blocks = _index_num_input_blocks( - idx, x.chunksize, out_chunksizes, x.numblocks - ) + max_num_input_blocks = _index_num_input_blocks( + idx, x.chunksize, out_chunksizes, x.numblocks + ) - out = map_selection( - None, # no function to apply after selection - selection_function, - x, - shape, - x.dtype, - target_chunks, - max_num_input_blocks=max_num_input_blocks, - ) - else: - # use map_direct, which can't be fused - # (note that it should be possible to re-write as general_blockwise with more work) - - # memory allocated by reading one chunk from input array - # note that although the output chunk will overlap multiple input chunks, zarr will - # read the chunks in series, reusing the buffer - extra_projected_mem = x.chunkmem - - out = map_direct( - _read_index_chunk, - x, - shape=shape, - dtype=dtype, - chunks=target_chunks, - extra_projected_mem=extra_projected_mem, - target_chunks=target_chunks, - selection=selection, - ) + out = map_selection( + None, # no function to apply after selection + selection_function, + x, + shape, + x.dtype, + target_chunks, + max_num_input_blocks=max_num_input_blocks, + ) # merge chunks for any dims with step > 1 so they are # the same size as the 
input (or slightly smaller due to rounding) @@ -641,10 +621,6 @@ def selection_function(out_key): return out -def _is_basic_selection(idx: ndindex.Tuple): - return all(isinstance(ia, (ndindex.Integer, ndindex.Slice)) for ia in idx.args) - - def _index_num_input_blocks( idx: ndindex.Tuple, in_chunksizes, out_chunksizes, numblocks ): @@ -661,21 +637,27 @@ def _index_num_input_blocks( # step is not a multiple of chunk size, and output chunks have more than one element # so some output chunks will access two input chunks num *= 2 + elif isinstance(ia, ndindex.IntegerArray): + # in the worst case, elements could be retrieved from all blocks + # TODO: improve to calculate the actual max input blocks + num *= nb else: - raise NotImplementedError("Only integer or slice indexes are supported.") + raise NotImplementedError( + "Only integer, slice, or int array indexes are supported." + ) return num -def create_basic_indexer(selection, shape, chunks): +def _create_zarr_indexer(selection, shape, chunks): if zarr.__version__[0] == "3": from zarr.core.chunk_grids import RegularChunkGrid - from zarr.core.indexing import BasicIndexer + from zarr.core.indexing import OrthogonalIndexer - return BasicIndexer(selection, shape, RegularChunkGrid(chunk_shape=chunks)) + return OrthogonalIndexer(selection, shape, RegularChunkGrid(chunk_shape=chunks)) else: - from zarr.indexing import BasicIndexer + from zarr.indexing import OrthogonalIndexer - return BasicIndexer(selection, ZarrArrayIndexingAdaptor(shape, chunks)) + return OrthogonalIndexer(selection, ZarrArrayIndexingAdaptor(shape, chunks)) @dataclass @@ -706,8 +688,8 @@ def _assemble_index_chunk( out_coords = block_id in_sel = selection_function(("out",) + out_coords) - # use a Zarr BasicIndexer to convert this to input coordinates - indexer = create_basic_indexer(in_sel, in_shape, in_chunksize) + # use a Zarr indexer to convert this to input coordinates + indexer = _create_zarr_indexer(in_sel, in_shape, in_chunksize) shape = 
indexer.shape out = np.empty(shape, dtype=dtype) @@ -793,8 +775,8 @@ def key_function(out_key): # compute the selection on x required to get the relevant chunk for out_key in_sel = selection_function(out_key) - # use a Zarr BasicIndexer to convert selection to input coordinates - indexer = create_basic_indexer(in_sel, x.shape, x.chunksize) + # use a Zarr indexer to convert selection to input coordinates + indexer = _create_zarr_indexer(in_sel, x.shape, x.chunksize) return ( iter(tuple((x.name,) + chunk_coords for (chunk_coords, _, _) in indexer)),