Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added Auto Opt, GPU and jax.Array #26

Merged
Prev Previous commit
Made a comment about the memory allocation of the output arrays.
philip-paul-mueller committed Oct 4, 2024
commit 1f48a1eccbf525281c9473983a96881b8ca90563
7 changes: 7 additions & 0 deletions src/jace/translated_jaxpr_sdfg.py
Original file line number Diff line number Diff line change
@@ -163,6 +163,9 @@ def __call__(
in_val = in_val.__array__() # noqa: PLW2901 [redefined-loop-name] # JAX arrays do not expose the __array_interface__.
sdfg_call_args[in_name] = in_val

# Allocate the output arrays.
# In DaCe the output arrays are created by the `CompiledSDFG` calls and all
# calls share the same arrays. In JaCe the output arrays are distinct.
arrays = self.sdfg.arrays
for output_name in self.output_names:
sdfg_call_args[output_name] = dace_data.make_array_from_descriptor(arrays[output_name])
@@ -178,6 +181,10 @@ def __call__(
dace.Config.set("compiler", "allow_view_arguments", value=True)
self.compiled_sdfg(**sdfg_call_args)

# DaCe writes the results either into CuPy or NumPy arrays. For compatibility
# with JAX we will now turn them into `jax.Array`s. Note that this is safe
# because we created these arrays in this function explicitly. Thus when
# this function ends, there is no writable reference to these arrays left.
return [
util.move_into_jax_array(sdfg_call_args[output_name])
for output_name in self.output_names