Skip to content

Make create_array signatures consistent #2819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changes/2819.chore.rst
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since "chore type" changelog entries don't get rendered (e.g., see https://zarr.readthedocs.io/en/stable/release-notes.html#misc), I'd recommend splitting this up into a "feature" entry for the updated signatures, and a separate "feature" for the change in default fill_value arguments.

Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Ensure that invocations of ``create_array`` use consistent keyword arguments, with consistent defaults.
Specifically, ``zarr.api.synchronous.create_array`` now takes a ``write_data`` keyword argument; the
``create_array`` method on ``zarr.Group`` takes ``data`` and ``write_data`` keyword arguments. The ``fill_value``
keyword argument of the various invocations of ``create_array`` has been consistently set to ``None``, where previously it was either ``None`` or ``0``.
4 changes: 2 additions & 2 deletions src/zarr/api/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,10 +860,10 @@ async def open_group(
async def create(
shape: ChunkCoords | int,
*, # Note: this is a change from v2
chunks: ChunkCoords | int | None = None, # TODO: v2 allowed chunks=True
chunks: ChunkCoords | int | bool | None = None,
dtype: ZDTypeLike | None = None,
compressor: CompressorLike = "auto",
fill_value: Any | None = 0, # TODO: need type
fill_value: Any | None = None, # TODO: need type
order: MemoryOrder | None = None,
store: str | StoreLike | None = None,
synchronizer: Any | None = None,
Expand Down
48 changes: 35 additions & 13 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from zarr.core.buffer import Buffer, BufferPrototype
from zarr.core.chunk_key_encodings import ChunkKeyEncodingLike
from zarr.core.common import MemoryOrder
from zarr.core.dtype import ZDTypeLike

logger = logging.getLogger("zarr.group")

Expand Down Expand Up @@ -999,22 +1000,24 @@ async def create_array(
self,
name: str,
*,
shape: ShapeLike,
dtype: npt.DTypeLike,
shape: ShapeLike | None = None,
dtype: ZDTypeLike | None = None,
data: np.ndarray[Any, np.dtype[Any]] | None = None,
chunks: ChunkCoords | Literal["auto"] = "auto",
shards: ShardsLike | None = None,
filters: FiltersLike = "auto",
compressors: CompressorsLike = "auto",
compressor: CompressorLike = "auto",
serializer: SerializerLike = "auto",
fill_value: Any | None = 0,
fill_value: Any | None = None,
order: MemoryOrder | None = None,
attributes: dict[str, JSON] | None = None,
chunk_key_encoding: ChunkKeyEncodingLike | None = None,
dimension_names: DimensionNames = None,
storage_options: dict[str, Any] | None = None,
overwrite: bool = False,
config: ArrayConfig | ArrayConfigLike | None = None,
config: ArrayConfigLike | None = None,
write_data: bool = True,
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
"""Create an array within this group.

Expand Down Expand Up @@ -1102,6 +1105,11 @@ async def create_array(
Whether to overwrite an array with the same name in the store, if one exists.
config : ArrayConfig or ArrayConfigLike, optional
Runtime configuration for the array.
write_data : bool
If a pre-existing array-like object was provided to this function via the ``data`` parameter
then ``write_data`` determines whether the values in that array-like object should be
written to the Zarr array created by this function. If ``write_data`` is ``False``, then the
array will be left empty.

Returns
-------
Expand All @@ -1116,6 +1124,7 @@ async def create_array(
name=name,
shape=shape,
dtype=dtype,
data=data,
chunks=chunks,
shards=shards,
filters=filters,
Expand All @@ -1130,6 +1139,7 @@ async def create_array(
storage_options=storage_options,
overwrite=overwrite,
config=config,
write_data=write_data,
)

@deprecated("Use AsyncGroup.create_array instead.")
Expand Down Expand Up @@ -2411,22 +2421,24 @@ def create_array(
self,
name: str,
*,
shape: ShapeLike,
dtype: npt.DTypeLike,
shape: ShapeLike | None = None,
dtype: ZDTypeLike | None = None,
data: np.ndarray[Any, np.dtype[Any]] | None = None,
chunks: ChunkCoords | Literal["auto"] = "auto",
shards: ShardsLike | None = None,
filters: FiltersLike = "auto",
compressors: CompressorsLike = "auto",
compressor: CompressorLike = "auto",
serializer: SerializerLike = "auto",
fill_value: Any | None = 0,
order: MemoryOrder | None = "C",
fill_value: Any | None = None,
order: MemoryOrder | None = None,
attributes: dict[str, JSON] | None = None,
chunk_key_encoding: ChunkKeyEncodingLike | None = None,
dimension_names: DimensionNames = None,
storage_options: dict[str, Any] | None = None,
overwrite: bool = False,
config: ArrayConfig | ArrayConfigLike | None = None,
config: ArrayConfigLike | None = None,
write_data: bool = True,
) -> Array:
"""Create an array within this group.

Expand All @@ -2437,10 +2449,13 @@ def create_array(
name : str
The name of the array relative to the group. If ``path`` is ``None``, the array will be located
at the root of the store.
shape : ChunkCoords
Shape of the array.
dtype : npt.DTypeLike
Data type of the array.
shape : ChunkCoords, optional
Shape of the array. Can be ``None`` if ``data`` is provided.
dtype : npt.DTypeLike | None
Data type of the array. Can be ``None`` if ``data`` is provided.
data : Array-like data to use for initializing the array. If this parameter is provided, the
``shape`` and ``dtype`` parameters must be identical to ``data.shape`` and ``data.dtype``,
or ``None``.
chunks : ChunkCoords, optional
Chunk shape of the array.
If not specified, default are guessed based on the shape and dtype.
Expand Down Expand Up @@ -2514,6 +2529,11 @@ def create_array(
Whether to overwrite an array with the same name in the store, if one exists.
config : ArrayConfig or ArrayConfigLike, optional
Runtime configuration for the array.
write_data : bool
If a pre-existing array-like object was provided to this function via the ``data`` parameter
then ``write_data`` determines whether the values in that array-like object should be
written to the Zarr array created by this function. If ``write_data`` is ``False``, then the
array will be left empty.

Returns
-------
Expand All @@ -2528,6 +2548,7 @@ def create_array(
name=name,
shape=shape,
dtype=dtype,
data=data,
chunks=chunks,
shards=shards,
fill_value=fill_value,
Expand All @@ -2541,6 +2562,7 @@ def create_array(
overwrite=overwrite,
storage_options=storage_options,
config=config,
write_data=write_data,
)
)
)
Expand Down
40 changes: 40 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import inspect
import pathlib
import re
from typing import TYPE_CHECKING

Expand All @@ -8,6 +10,7 @@

if TYPE_CHECKING:
import pathlib
from collections.abc import Callable

from zarr.abc.store import Store
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
Expand Down Expand Up @@ -1216,6 +1219,43 @@ def test_open_array_with_mode_r_plus(store: Store, zarr_format: ZarrFormat) -> N
z2[:] = 3


@pytest.mark.parametrize(
    ("a_func", "b_func"),
    [
        (zarr.api.asynchronous.create_array, zarr.api.synchronous.create_array),
        (zarr.api.asynchronous.save, zarr.api.synchronous.save),
        (zarr.api.asynchronous.save_array, zarr.api.synchronous.save_array),
        (zarr.api.asynchronous.save_group, zarr.api.synchronous.save_group),
        (zarr.api.asynchronous.open_group, zarr.api.synchronous.open_group),
        (zarr.api.asynchronous.create, zarr.api.synchronous.create),
    ],
)
def test_consistent_signatures(
    a_func: Callable[[object], object], b_func: Callable[[object], object]
) -> None:
    """
    Ensure that pairs of functions have the same signature.

    ``a_func`` is treated as the reference ("base") and ``b_func`` as the
    function under test. Discrepancies are collected into three buckets
    (parameters missing from either side, and parameters present on both
    sides that differ) so that a failure reports every mismatch at once.
    """
    base_sig = inspect.signature(a_func)
    test_sig = inspect.signature(b_func)
    wrong: dict[str, list[object]] = {
        "missing_from_test": [],
        "missing_from_base": [],
        "wrong_type": [],
    }
    for key, value in base_sig.parameters.items():
        if key not in test_sig.parameters:
            wrong["missing_from_test"].append((key, value))
    for key, value in test_sig.parameters.items():
        if key not in base_sig.parameters:
            wrong["missing_from_base"].append((key, value))
        # Only compare parameters that exist on both sides; indexing the base
        # signature with a missing key would raise KeyError and hide the
        # collected diagnostics.
        elif base_sig.parameters[key] != value:
            wrong["wrong_type"].append({key: {"test": value, "base": base_sig.parameters[key]}})
    assert wrong["missing_from_base"] == []
    assert wrong["missing_from_test"] == []
    assert wrong["wrong_type"] == []


def test_api_exports() -> None:
"""
Test that the sync API and the async API export the same objects
Expand Down
56 changes: 56 additions & 0 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,43 @@ def test_auto_partition_auto_shards(
assert auto_shards == expected_shards


def test_chunks_and_shards() -> None:
    """
    Check the ``chunks`` and ``shards`` attributes reported by arrays created
    with and without sharding, for both Zarr format 3 and Zarr format 2.
    """
    root = StorePath(MemoryStore())
    array_shape = (100, 100)
    chunk_shape = (5, 5)
    shard_shape = (10, 10)

    # Zarr v3 array created without a ``shards`` argument.
    plain_v3 = zarr.create_array(
        store=root / "v3", shape=array_shape, chunks=chunk_shape, dtype="i4"
    )
    assert plain_v3.chunks == chunk_shape
    assert plain_v3.shards is None

    # Zarr v3 array created with an explicit ``shards`` argument.
    sharded_v3 = zarr.create_array(
        store=root / "v3_sharding",
        shape=array_shape,
        chunks=chunk_shape,
        shards=shard_shape,
        dtype="i4",
    )
    assert sharded_v3.chunks == chunk_shape
    assert sharded_v3.shards == shard_shape

    # Zarr v2 array created without a ``shards`` argument.
    plain_v2 = zarr.create_array(
        store=root / "v2",
        shape=array_shape,
        chunks=chunk_shape,
        zarr_format=2,
        dtype="i4",
    )
    assert plain_v2.chunks == chunk_shape
    assert plain_v2.shards is None


@pytest.mark.parametrize("store", ["memory"], indirect=True)
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@pytest.mark.parametrize(
    ("dtype", "fill_value_expected"),
    [
        ("<U4", ""),
        ("<S4", b""),
        ("i", 0),
        ("f", 0.0),
    ],
)
def test_default_fill_value(dtype: str, fill_value_expected: object, store: Store) -> None:
    """
    Check that an array created without an explicit ``fill_value`` reports
    the expected per-dtype default as its ``fill_value``.
    """
    arr = zarr.create_array(store=store, shape=(5,), chunks=(5,), dtype=dtype)
    observed = arr.fill_value
    assert observed == fill_value_expected


@pytest.mark.parametrize("store", ["memory"], indirect=True)
class TestCreateArray:
@staticmethod
Expand Down Expand Up @@ -1747,6 +1784,25 @@ def test_multiprocessing(store: Store, method: Literal["fork", "spawn", "forkser
assert all(np.array_equal(r, data) for r in results)


def test_create_array_method_signature() -> None:
    """
    Check that ``AsyncGroup.create_array`` accepts every keyword argument of the
    ``create_array`` function, apart from a small set of arguments that are
    missing or mean something different on the group method.
    """
    func_params = inspect.signature(create_array).parameters
    method_params = inspect.signature(AsyncGroup.create_array).parameters
    # keyword arguments that are either missing or have different semantics when
    # create_array is invoked as a group method
    skipped = {"zarr_format", "store", "name"}
    # TODO: make this test stronger. right now, it only checks that all the parameters in the
    # function signature are used in the method signature. we can be more strict and check that
    # the method signature uses no extra parameters.
    expected = {(name, param) for name, param in func_params.items() if name not in skipped}
    assert (expected - set(method_params.items())) == set()


async def test_sharding_coordinate_selection() -> None:
store = MemoryStore()
g = zarr.open_group(store, mode="w")
Expand Down
18 changes: 17 additions & 1 deletion tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,7 @@ def test_create_nodes_concurrency_limit(store: MemoryStore) -> None:
@pytest.mark.parametrize(
("a_func", "b_func"),
[
(zarr.core.group.AsyncGroup.create_array, zarr.core.group.Group.create_array),
(zarr.core.group.AsyncGroup.create_hierarchy, zarr.core.group.Group.create_hierarchy),
(zarr.core.group.create_hierarchy, zarr.core.sync_group.create_hierarchy),
(zarr.core.group.create_nodes, zarr.core.sync_group.create_nodes),
Expand All @@ -1546,7 +1547,22 @@ def test_consistent_signatures(
"""
base_sig = inspect.signature(a_func)
test_sig = inspect.signature(b_func)
assert test_sig.parameters == base_sig.parameters
wrong: dict[str, list[object]] = {
"missing_from_test": [],
"missing_from_base": [],
"wrong_type": [],
}
for key, value in base_sig.parameters.items():
if key not in test_sig.parameters:
wrong["missing_from_test"].append((key, value))
for key, value in test_sig.parameters.items():
if key not in base_sig.parameters:
wrong["missing_from_base"].append((key, value))
if base_sig.parameters[key] != value:
wrong["wrong_type"].append({key: {"test": value, "base": base_sig.parameters[key]}})
assert wrong["missing_from_base"] == []
assert wrong["missing_from_test"] == []
assert wrong["wrong_type"] == []


@pytest.mark.parametrize("store", ["memory"], indirect=True)
Expand Down
Loading