
Jax integration #304

Open · alxmrs opened this issue Sep 10, 2023 · 10 comments

Comments

@alxmrs (Contributor) commented Sep 10, 2023

Can the core array API ops of Cubed be implemented in JAX, such that everything easily compiles to accelerators? Could this solve the common pain point of running out of GPU memory? How would other constraints (e.g. GPU bandwidth limits) be handled? What is the ideal distributed runtime environment to make the most of this? Could spot GPU instances be used (serverless accelerators)?

@tomwhite (Member)

Thanks @alxmrs for opening this issue. I'm not familiar enough with JAX or GPUs to answer these questions, but I'd be happy to support or discuss an initiative in this direction. Is there a small piece of work that you have in mind that could be used to explore this?

@alxmrs (Contributor, Author) commented Sep 19, 2023

The JAX docs provide a few good toy examples that could be useful for validating this idea.

Check out this tutorial: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html

This tutorial on distributing computation on a pool of TPUs (specifically, the neural net section) may be of interest, too:

https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#examples-neural-networks

A high-level goal for JAX + Cubed could be to make managing GPU memory effortless.

@tomwhite (Member)

Thanks for the pointers @alxmrs.

Thinking about how Cubed might hook into this, the main idea in Cubed is that every array is backed by Zarr, and an operation maps one Zarr array to another, by working in parallel on chunks (see https://tom-e-white.com/cubed/cubed.slides.html#/1).

Every operation is ultimately expressible as a blockwise operation (or a rechunk operation, but let's ignore that here), which ends up in the apply_blockwise function:

https://github.com/tomwhite/cubed/blob/0d13e4f2b12c22d1c41b9f4ea693266b21d808d0/cubed/primitive/blockwise.py#L53-L75

The key parts are:

  1. Reading each arg from Zarr into (CPU) memory (line 67),
  2. Invoking the function on the args (line 70), and
  3. Writing the result from (CPU) memory back to Zarr (line 73 or 75)

To change this to use JAX, we'd have to 1. read from Zarr into JAX arrays, 2. invoke the relevant JAX function on the arrays, 3. write the resulting JAX array to Zarr.
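As a very rough, untested sketch of those three changes (read_chunk and out_selection are made-up helper names for illustration, not Cubed's actual internals; config stands for the BlockwiseSpec-like object that apply_blockwise receives):

import numpy as np
import jax.numpy as jnp

def apply_blockwise_jax(out_key, *, config):
    # 1. read each input chunk from Zarr and place it on the device as a JAX array
    in_keys = config.key_function(out_key)
    args = [jnp.asarray(read_chunk(config.reads_map, k)) for k in in_keys]  # read_chunk is hypothetical

    # 2. invoke the (JAX) function on the chunks
    result = config.function(*args)

    # 3. copy the result back to host memory and write it to Zarr
    config.write.array[out_selection(out_key)] = np.asarray(result)  # out_selection is hypothetical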

In fact, there might not be anything to do for 2., since you could call cubed.map_blocks with a JAX function.

This might be enough for the FFT example, although I'm a bit hazy on whether any post-processing is needed on the chunked (sharded) output.
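For example, a minimal, untested sketch of step 2 on its own (chunk sizes, memory limit and dtype are just illustrative guesses):

import cubed
import cubed.array_api as xp
import jax.numpy as jnp

spec = cubed.Spec(allowed_mem="2GB")
x = xp.ones((4096, 4096), chunks=(1024, 4096), spec=spec)

def fft_block(block):
    # runs on one in-memory chunk; JAX dispatches it to the local accelerator
    return jnp.fft.fft(block, axis=-1)

# whether the returned JAX array writes back to Zarr cleanly is the step 3 question above
y = cubed.map_blocks(fft_block, x, dtype="complex64")
y.compute()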

A final thought. Is KvikIO, which does direct Zarr-GPU IO, related to this?

@TomNicholas (Member)

> read from Zarr into JAX arrays

> A final thought. Is KvikIO, which does direct Zarr-GPU IO, related to this?

Reading the xarray blog post on this that @dcherian and @weiji14 wrote, it seems they used a Zarr store provided by kvikio. I expect Cubed could use this to load data from Zarr directly to the GPU in the form of a cupy array, which would be cool. (Or you could probably even use the xarray backend they wrote alongside cubed-xarray to achieve this.)
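For reference, the route described in the blog post looks roughly like this (untested here; the path is made up, and the meta_array argument needs a reasonably recent zarr-python):

import cupy
import zarr
import kvikio.zarr

# GDSStore reads chunk bytes via GPUDirect Storage
store = kvikio.zarr.GDSStore("/path/to/data.zarr")

# meta_array tells zarr to materialise chunks as cupy arrays instead of NumPy arrays
z = zarr.open_array(store, mode="r", meta_array=cupy.empty(()))
chunk = z[:1024, :1024]  # a cupy.ndarray living in GPU memory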

I tried to find out whether there was anything similar for JAX, but didn't see anything (only this jax-ml/jax#17534). Writing from JAX to tensorstore was done for checkpointing language models (https://blog.research.google/2022/09/tensorstore-for-high-performance.html?m=1), but one would have thought that making tensorstore return JAX arrays directly would have been tried...

@weiji14 commented Feb 11, 2024

KvikIO loads data into cupy, but it should technically be possible to zero-copy cupy arrays to JAX, PyTorch, or any array library that implements conversion via DLPack or the __cuda_array_interface__ protocol. It looks like JAX supports this already (jax-ml/jax#1100)? But I haven't tried this end to end yet. There's also NVIDIA DALI, which seems to work with JAX (https://docs.nvidia.com/deeplearning/dali/archives/dali_1_32_0/user-guide/docs/plugins/jax_tutorials.html#jax), but the interface is a little less convenient since you need to set up a pipeline. Generally, the integration between the RAPIDS AI libraries (which build on cupy) is a bit better on the PyTorch side, with the RAPIDS Memory Manager (https://github.com/rapidsai/rmm/blob/branch-24.04/README.md#using-rmm-with-third-party-libraries).
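A small untested sketch of that zero-copy hand-off (newer JAX versions accept anything implementing __dlpack__ directly; older ones need the explicit capsule):

import cupy as cp
import jax
import jax.numpy as jnp

cp_arr = cp.arange(1024, dtype=cp.float32)

# cupy arrays implement __dlpack__, so this should not copy the underlying GPU buffer
jax_arr = jax.dlpack.from_dlpack(cp_arr)
# on older JAX: jax_arr = jax.dlpack.from_dlpack(cp_arr.toDlpack())

print(jax_arr.shape, jnp.sum(jax_arr))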

@alxmrs (Contributor, Author) commented Jul 23, 2024

Hey @tomwhite, I have a question for you: to run jax arrays on accelerators (M1+ chips, GPUs, TPUs, etc.), someone needs to call jax.jit: https://jax.readthedocs.io/en/latest/quickstart.html#just-in-time-compilation-with-jax-jit

Where is a good place to make this kind of call within Cubed? Is this something that should be handled by an Executor (this seems not so ideal)?

(Here's some more in-depth documentation on JAX's jit: https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation.)

A related concern that I haven't properly figured out yet: How should this intersect with devices and sharding?
https://jax.readthedocs.io/en/latest/sharded-computation.html

@tomwhite (Member)

> Where is a good place to make this kind of call within Cubed? Is this something that should be handled by an Executor (this seems not so ideal)?

Possibly as part of DAG finalization, after (Cubed) optimization has been run, although the function you want to jit will be the function in BlockwiseSpec:

class BlockwiseSpec:
    """Specification for how to run blockwise on an array.

    This is similar to ``CopySpec`` in rechunker.

    Attributes
    ----------
    key_function : Callable
        A function that maps an output chunk key to one or more input chunk keys.
    function : Callable
        A function that maps input chunks to an output chunk.
    function_nargs : int
        The number of array arguments that ``function`` takes.
    num_input_blocks : Tuple[int, ...]
        The number of input blocks read from each input array.
    reads_map : Dict[str, CubedArrayProxy]
        Read proxy dictionary keyed by array name.
    write : CubedArrayProxy
        Write proxy with an ``array`` attribute that supports ``__setitem__``.
    """

    key_function: Callable[..., Any]
    function: Callable[..., Any]
    function_nargs: int
    num_input_blocks: Tuple[int, ...]
    reads_map: Dict[str, CubedArrayProxy]
    write: CubedArrayProxy
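A very rough sketch of the "jit at DAG finalization" idea (the attribute names used to reach the BlockwiseSpec are guesses rather than Cubed's actual internals, and this assumes the spec is mutable or can be rebuilt):

import jax

def jit_blockwise_functions(dag):
    # walk the optimized plan and wrap each blockwise chunk function in jax.jit,
    # so the per-chunk computation is compiled for the local accelerator
    for _, data in dag.nodes(data=True):  # the plan is a networkx DiGraph
        spec = getattr(data.get("pipeline"), "config", None)  # attribute names are guesses
        if isinstance(spec, BlockwiseSpec):
            spec.function = jax.jit(spec.function)
    return dag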

What's the simplest possible example to start with?

@alxmrs (Contributor, Author) commented Jul 23, 2024

Thanks for your suggestion, Tom. I've prototyped something here: alxmrs#1

For now, it looks like I need to work on landing the M1 PR before I can take this any further.

@tomwhite (Member)

Nice!

@tomwhite (Member)

I think compiling the (Cubed optimized) blockwise functions using AOT compilation (as you mentioned in #490 (comment)), and then exporting them so they can run in other processes (https://jax.readthedocs.io/en/latest/export/export.html) may be the way to go. Perhaps this is worth trying on CPU first.
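For reference, an untested sketch of that export workflow for a single blockwise function (the chunk shape and dtype are made up; jax.export is available in recent JAX releases):

import jax
import jax.numpy as jnp
from jax import export

def block_fn(x):
    return jnp.fft.fft(x, axis=-1)

# AOT-trace for a fixed chunk shape/dtype, then serialize so a worker process
# can rehydrate and call it without re-tracing
exported = export.export(jax.jit(block_fn))(jax.ShapeDtypeStruct((1024, 4096), jnp.float32))
blob = exported.serialize()

# ... in the worker process:
rehydrated = export.deserialize(blob)
y = rehydrated.call(jnp.ones((1024, 4096), jnp.float32))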
