Icechunk store #633

Draft · wants to merge 2 commits into base: main
86 changes: 86 additions & 0 deletions cubed/icechunk.py
@@ -0,0 +1,86 @@
from typing import TYPE_CHECKING, Any, List, Sequence, Union

import zarr
from icechunk import IcechunkStore

from cubed import compute
from cubed.core.array import CoreArray
from cubed.core.ops import blockwise
from cubed.runtime.types import Callback

if TYPE_CHECKING:
from cubed.array_api.array_object import Array


def store_icechunk(
store: IcechunkStore,
*,
sources: Union["Array", Sequence["Array"]],
targets: List[zarr.Array],
executor=None,
Review comment: could definitely address this later but regions is quite an important kwarg.

**kwargs: Any,
) -> None:
if isinstance(sources, CoreArray):
sources = [sources]
targets = [targets]

if any(not isinstance(s, CoreArray) for s in sources):
raise ValueError("All sources must be cubed array objects")

if len(sources) != len(targets):
raise ValueError(
f"Different number of sources ({len(sources)}) and targets ({len(targets)})"
)

arrays = []
for source, target in zip(sources, targets):
identity = lambda a: a
ind = tuple(range(source.ndim))
array = blockwise(
identity,
ind,
source,
ind,
dtype=source.dtype,
align_arrays=False,
target_store=target,
return_writes_stores=True,
)
arrays.append(array)

# use a callback to merge icechunk stores
store_callback = IcechunkStoreCallback()
# add to other callbacks the user may have set
callbacks = kwargs.pop("callbacks", [])
callbacks = [store_callback] + list(callbacks)

compute(
*arrays,
executor=executor,
_return_in_memory_array=False,
callbacks=callbacks,
**kwargs,
)

# merge back into the store passed into this function
merged_store = store_callback.store
store.merge(merged_store.change_set_bytes())


class IcechunkStoreCallback(Callback):
def on_compute_start(self, event):
self.store = None

def on_task_end(self, event):
result = event.result
if result is None:
return
for store in result:
if self.store is None:
self.store = store
else:
self.store.merge(store.change_set_bytes())
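
For orientation, here is a minimal usage sketch of `store_icechunk`, modelled on `test_store_icechunk` further down in this PR. It assumes the same pre-1.0 icechunk API the tests use (`IcechunkStore.create`, `StorageConfig.filesystem`); the path and array values are illustrative, and the tests additionally set `read_only` and an inline-chunk threshold (see the test file for the exact calls).

```python
import icechunk
import zarr

import cubed.array_api as xp
from cubed.icechunk import store_icechunk

# a small cubed array to write (values and chunking are illustrative)
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))

# create an Icechunk store on the local filesystem
store = icechunk.IcechunkStore.create(
    storage=icechunk.StorageConfig.filesystem("/tmp/example-icechunk"),
)

# create the target Zarr array inside the store
group = zarr.group(store=store, overwrite=True)
target = group.create_array(
    "a", shape=a.shape, chunk_shape=a.chunksize, dtype=a.dtype
)

# write the cubed array block by block, merging per-task changesets
store_icechunk(store, sources=a, targets=target)

# make the written chunks visible to readers of the store
store.commit("write a")
```

Note that `store_icechunk` itself does not commit; the caller decides when the merged changeset becomes a snapshot, as the test does with `store.commit`.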
Review conversation (on the changeset-merging callback):

@tomwhite It looks like callbacks are "accumulated" on every worker. Is that right?


@dcherian My only concern with it is the implementation of merge_stores is probably very slow, using no parallelism, but I don't see any issues with making it public


Right we can fix that, but note that dask and presumably cubed are accumulating on remote workers already, so there is already some parallelism in how it is used.

Moreover it'd be nice not to have store.change_set_bytes be the public API :)

Member Author:

> It looks like callbacks are "accumulated" on every worker. Is that right?

No, they are being accumulated on the client so there is no parallelism. I don't know what Icechunk is doing here, but would it be possible to merge a batch in one go rather than one at a time? Could that be more efficient?

@dcherian (Dec 4, 2024):

> but would it be possible to merge a batch in one go rather than one at a time

we'll have to build some rust API, we'll get to this eventually.

> No, they are being accumulated on the client so there is no parallelism.

ah ok. In that case, how about calling reduction on each array that was written? That way you parallelize the merge across blocks for each array, and then the only serial bit is the merging across arrays, which will be a lot smaller. I considered this approach for dask, but then just wrote out a tree reduction across all chunks.

EDIT: or is the reduction approach not viable because you need to serialize to Zarr at some point?

Member Author:

> or is the reduction approach not viable because you need to serialize to Zarr at some point?

Yes, that's basically it - Cubed also separates the data paths for array manipulations (contents of the blocks) from the metadata operations (block IDs and - for Icechunk - the changesets). So I think merging in batches would be more feasible.

> we'll have to build some rust API, we'll get to this eventually.

+1


> Yes, that's basically it

:( I was afraid so.
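
To make the batching idea discussed above concrete, here is one possible sketch (not part of this PR) of a callback that collects changeset bytes as tasks finish and merges them in a single pass when the computation ends. It assumes cubed's `Callback` base class also exposes an `on_compute_end` hook; the merge loop is still serial on the client, and collapsing it into a single call would need the batch-merge Rust API mentioned in the thread.

```python
from cubed.runtime.types import Callback


class BatchedIcechunkStoreCallback(Callback):
    """Hypothetical batched variant of IcechunkStoreCallback.

    Instead of merging every returned store into the accumulator as it
    arrives, collect the changeset bytes and merge them at the end of
    the computation.
    """

    def on_compute_start(self, event):
        self.store = None
        self.change_sets = []

    def on_task_end(self, event):
        result = event.result
        if result is None:
            return
        for store in result:
            if self.store is None:
                # the first store seen becomes the merge target
                self.store = store
            else:
                self.change_sets.append(store.change_set_bytes())

    def on_compute_end(self, event):
        # still one merge per changeset; a future icechunk batch-merge
        # API could replace this loop with a single call
        for change_set in self.change_sets:
            self.store.merge(change_set)
```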

10 changes: 10 additions & 0 deletions cubed/primitive/blockwise.py
@@ -72,6 +72,7 @@ class BlockwiseSpec:
iterable_input_blocks: Tuple[bool, ...]
reads_map: Dict[str, CubedArrayProxy]
writes_list: List[CubedArrayProxy]
return_writes_stores: bool = False


def apply_blockwise(out_coords: List[int], *, config: BlockwiseSpec) -> None:
@@ -100,6 +101,9 @@ def apply_blockwise(out_coords: List[int], *, config: BlockwiseSpec) -> None:
result = backend_array_to_numpy_array(result)
config.writes_list[i].open()[out_chunk_key] = result

if config.return_writes_stores:
return [write_proxy.open().store for write_proxy in config.writes_list]


def get_results_in_different_scope(out_coords: List[int], *, config: BlockwiseSpec):
# wrap function call in a function so that args go out of scope (and free memory) as soon as results are returned
@@ -267,6 +271,7 @@ def general_blockwise(
function_nargs: Optional[int] = None,
num_input_blocks: Optional[Tuple[int, ...]] = None,
iterable_input_blocks: Optional[Tuple[bool, ...]] = None,
return_writes_stores: bool = False,
**kwargs,
) -> PrimitiveOperation:
"""A more general form of ``blockwise`` that uses a function to specify the block
@@ -367,6 +372,7 @@
iterable_input_blocks,
read_proxies,
write_proxies,
return_writes_stores,
)

# calculate projected memory
@@ -536,6 +542,7 @@ def fused_func(*args):
function_nargs = pipeline1.config.function_nargs
read_proxies = pipeline1.config.reads_map
write_proxies = pipeline2.config.writes_list
return_writes_stores = pipeline2.config.return_writes_stores
num_input_blocks = tuple(
n * pipeline2.config.num_input_blocks[0]
for n in pipeline1.config.num_input_blocks
@@ -549,6 +556,7 @@
iterable_input_blocks,
read_proxies,
write_proxies,
return_writes_stores,
)

source_array_names = primitive_op1.source_array_names
@@ -681,6 +689,7 @@ def fuse_blockwise_specs(
for bws in predecessor_bw_specs:
read_proxies.update(bws.reads_map)
write_proxies = bw_spec.writes_list
return_writes_stores = bw_spec.return_writes_stores
return BlockwiseSpec(
fused_key_func,
fused_func,
@@ -689,6 +698,7 @@
fused_iterable_input_blocks,
read_proxies,
write_proxies,
return_writes_stores,
)


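
As a reading aid for the plumbing above, the sketch below restates the contract that `return_writes_stores` introduces; the proxy and spec classes here are simplified stand-ins, not cubed's real types. When the flag is set, each blockwise task returns the stores it wrote to, and that list is what `IcechunkStoreCallback` later receives as `event.result`.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class WriteProxyStandIn:
    """Simplified stand-in for CubedArrayProxy: open() yields an object
    whose .store is the store that was written to."""

    store: Any

    def open(self) -> "WriteProxyStandIn":
        return self


@dataclass
class SpecStandIn:
    """Simplified stand-in for BlockwiseSpec, reduced to the two fields
    relevant to the new flag."""

    writes_list: List[WriteProxyStandIn]
    return_writes_stores: bool = False


def apply_blockwise_sketch(spec: SpecStandIn) -> Optional[List[Any]]:
    # ... chunk computation and writes elided ...
    if spec.return_writes_stores:
        # the task's return value: one store per write target
        return [proxy.open().store for proxy in spec.writes_list]
    return None
```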
87 changes: 87 additions & 0 deletions cubed/tests/test_icechunk.py
@@ -0,0 +1,87 @@
from typing import Iterable

import icechunk
import numpy as np
import pytest
import zarr
from numpy.testing import assert_array_equal

import cubed
import cubed.array_api as xp
import cubed.random
from cubed.icechunk import store_icechunk
from cubed.tests.utils import MAIN_EXECUTORS


@pytest.fixture(
scope="module",
params=MAIN_EXECUTORS,
ids=[executor.name for executor in MAIN_EXECUTORS],
)
def executor(request):
return request.param


def create_icechunk(a, tmp_path, /, *, dtype=None, chunks=None):
# from dask.asarray
if not isinstance(getattr(a, "shape", None), Iterable):
# ensure blocks are arrays
a = np.asarray(a, dtype=dtype)
if dtype is None:
dtype = a.dtype

store = icechunk.IcechunkStore.create(
storage=icechunk.StorageConfig.filesystem(tmp_path / "icechunk"),
config=icechunk.StoreConfig(inline_chunk_threshold_bytes=1),
read_only=False,
)

group = zarr.group(store=store, overwrite=True)
arr = group.create_array("a", shape=a.shape, chunk_shape=chunks, dtype=dtype)

arr[...] = a

store.commit("commit 1")


def test_from_zarr_icechunk(tmp_path, executor):
create_icechunk(
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
tmp_path,
chunks=(2, 2),
)

store = icechunk.IcechunkStore.open_existing(
storage=icechunk.StorageConfig.filesystem(tmp_path / "icechunk"),
)

a = cubed.from_zarr(store, path="a")
assert_array_equal(
a.compute(executor=executor), np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
)


def test_store_icechunk(tmp_path, executor):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))

store = icechunk.IcechunkStore.create(
storage=icechunk.StorageConfig.filesystem(tmp_path / "icechunk"),
config=icechunk.StoreConfig(inline_chunk_threshold_bytes=1),
read_only=False,
)
with store.preserve_read_only():
group = zarr.group(store=store, overwrite=True)
target = group.create_array(
"a", shape=a.shape, chunk_shape=a.chunksize, dtype=a.dtype
)
store_icechunk(store, sources=a, targets=target, executor=executor)
store.commit("commit 1")

# reopen store and check contents of array
store = icechunk.IcechunkStore.open_existing(
storage=icechunk.StorageConfig.filesystem(tmp_path / "icechunk"),
)
group = zarr.open_group(store=store, mode="r")
assert_array_equal(
cubed.from_array(group["a"])[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
)