Skip to content

Commit

Permalink
Adding JIT functionality to Plan finalization.
Browse files Browse the repository at this point in the history
Created an affordance to apply a JIT decorator to blockwise functions after operator fusion. This will let users make better use of accelerators.

cubed-dev#508 needs to be merged first.
  • Loading branch information
alxmrs committed Jul 23, 2024
1 parent 6b6dc4f commit 2f13825
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 2 deletions.
35 changes: 33 additions & 2 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import inspect
import tempfile
import uuid
Expand All @@ -21,6 +22,8 @@

# Module-level counter; gensym declares it ``global`` (presumably so it can
# bump it for each generated name -- confirm against gensym's body).
sym_counter = 0

# Type alias for a decorator: a callable that takes a callable and returns a
# callable. Used to annotate the ``jit_function`` parameters in this module.
Decorator = Callable[[Callable], Callable]


def gensym(name="op"):
global sym_counter
Expand Down Expand Up @@ -182,13 +185,40 @@ def _create_lazy_zarr_arrays(self, dag):

return dag

def _compile_blockwise(self, dag, jit_function: Decorator) -> nx.MultiDiGraph:
    """Wrap the function of every blockwise op in ``dag`` with ``jit_function``.

    The graph is modified in place and the same object is returned, so
    callers that need the original graph untouched should pass in a copy.
    """
    for key in dag.nodes:
        attrs = dag.nodes[key]

        # Nodes without a primitive op have nothing to compile.
        if "primitive_op" not in attrs:
            continue

        pipeline = attrs["pipeline"]
        spec = pipeline.config
        # Only pipelines configured by a BlockwiseSpec are compiled.
        if not isinstance(spec, BlockwiseSpec):
            continue

        # Build new objects with dataclasses.replace instead of assigning to
        # fields (the original author notes these are frozen dataclasses).
        compiled = jit_function(spec.function)
        jitted_spec = dataclasses.replace(spec, function=compiled)
        attrs["pipeline"] = dataclasses.replace(pipeline, config=jitted_spec)

    return dag

@lru_cache
def _finalize_dag(
    self,
    optimize_graph: bool = True,
    optimize_function=None,
    jit_function: Optional[Decorator] = None,
) -> nx.MultiDiGraph:
    """Return a frozen copy of the plan's dag that is ready to execute.

    Parameters
    ----------
    optimize_graph
        When true, start from ``self.optimize(optimize_function).dag``;
        otherwise start from ``self.dag``.
    optimize_function
        Passed through to ``self.optimize`` when ``optimize_graph`` is true.
    jit_function
        Optional decorator applied to the function of every blockwise op
        via ``_compile_blockwise``. Any non-callable value (including
        ``None``) skips compilation.

    Returns
    -------
    The finalized dag, frozen with ``nx.freeze``.

    Notes
    -----
    The result is memoized by ``lru_cache``, so every argument, including
    ``self`` and ``jit_function``, must be hashable.
    """
    dag = self.optimize(optimize_function).dag if optimize_graph else self.dag
    # Work on a copy: _compile_blockwise and _create_lazy_zarr_arrays both
    # mutate the dag they are given.
    dag = dag.copy()
    if callable(jit_function):
        dag = self._compile_blockwise(dag, jit_function)
    dag = self._create_lazy_zarr_arrays(dag)
    return nx.freeze(dag)

Expand All @@ -198,11 +228,12 @@ def execute(
callbacks=None,
optimize_graph=True,
optimize_function=None,
jit_function=None,
resume=None,
spec=None,
**kwargs,
):
dag = self._finalize_dag(optimize_graph, optimize_function)
dag = self._finalize_dag(optimize_graph, optimize_function, jit_function)

compute_id = f"compute-{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}"

Expand Down
27 changes: 27 additions & 0 deletions cubed/tests/test_executor_features.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import os
import platform

import fsspec
Expand Down Expand Up @@ -264,3 +265,29 @@ def test_check_runtime_memory_modal(spec, modal_executor):
match=r"Runtime memory \(2097152000\) is less than allowed_mem \(4000000000\)",
):
c.compute(executor=modal_executor)


# Decorators exercised by the JIT test. The identity decorator is always
# present so the test still runs when no JIT library is installed.
JIT_FUNCTIONS = [lambda fn: fn]

# Optional backends. ImportError (the parent of ModuleNotFoundError) is caught
# so that a package which is installed but cannot be imported is skipped
# instead of aborting test collection.
try:
    from numba import jit as numba_jit
except ImportError:
    pass
else:
    JIT_FUNCTIONS.append(numba_jit)

# jax's jit is only tried when the array backend environment variable
# mentions jax.
if "jax" in os.environ.get("CUBED_BACKEND_ARRAY_API_MODULE", ""):
    try:
        from jax import jit as jax_jit
    except ImportError:
        pass
    else:
        JIT_FUNCTIONS.append(jax_jit)


@pytest.mark.parametrize("jit_function", JIT_FUNCTIONS)
def test_check_jit_compliation(spec, executor, jit_function):
    """Adding a matrix of ones gives the expected sum under each JIT decorator."""
    grid = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
    ones = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
    total = xp.add(grid, ones)

    # Compute with the decorator under test, then compare element-wise.
    actual = total.compute(executor=executor, jit_function=jit_function)
    expected = np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]])
    assert_array_equal(actual, expected)

0 comments on commit 2f13825

Please sign in to comment.