Jax integration #304
Comments
Thanks @alxmrs for opening this issue. I'm not familiar enough with JAX or GPUs to answer these questions, but I'd be happy to support or discuss an initiative in this direction. Is there a small piece of work that you have in mind that could be used to explore this?
The Jax docs may provide a few good toy examples useful for validating this idea. Check out this tutorial: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html

This tutorial on distributing computation across a pool of TPUs (specifically, the neural net section) may be of interest, too.

A high-level goal for Jax + Cubed may be to make managing GPU memory effortless.
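For concreteness, here is a minimal, non-Cubed sketch of the kind of single-host device parallelism that tutorial covers; it also runs on a CPU-only machine, where there is just one local device:

```python
import jax
import jax.numpy as jnp

# jax.pmap replicates a function across local devices and maps it over the
# leading axis of its input, one shard per device.
n_devices = jax.local_device_count()
x = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)
y = jax.pmap(lambda shard: shard * 2.0)(x)
print(y.shape)  # (n_devices, 4)
```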
Thanks for the pointers @alxmrs. Thinking about how Cubed might hook into this: the main idea in Cubed is that every array is backed by Zarr, and an operation maps one Zarr array to another by working in parallel on chunks (see https://tom-e-white.com/cubed/cubed.slides.html#/1). Every operation is ultimately expressible as a blockwise operation (or a rechunk operation, but let's ignore that here), which ends up in the blockwise primitive in cubed/primitive/blockwise.py. The key parts are reading the input chunks from Zarr into NumPy arrays, applying the function to them, and writing the resulting chunks back to Zarr.
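A rough illustrative sketch of that pattern (not Cubed's actual implementation; the helper name and the single-input signature are assumptions):

```python
import numpy as np
import zarr


def apply_to_chunk(func, source: zarr.Array, target: zarr.Array, key):
    """Apply `func` to the chunk selected by `key` (a tuple of slices)."""
    chunk = source[key]      # read one chunk from Zarr as a NumPy array
    result = func(chunk)     # apply the NumPy-compatible function
    target[key] = result     # write the resulting chunk back to Zarr


src = zarr.array(np.arange(16, dtype="f8").reshape(4, 4), chunks=(2, 2))
dst = zarr.zeros(src.shape, chunks=src.chunks, dtype=src.dtype)
apply_to_chunk(np.negative, src, dst, (slice(0, 2), slice(0, 2)))
```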
To change this to use JAX, we'd have to:

1. read from Zarr into JAX arrays,
2. invoke the relevant JAX function on the arrays,
3. write the resulting JAX array to Zarr.

In fact, there might not be anything to do for 2., since you could call the JAX function directly on the arrays. This might be enough for the FFT example, although I'm a bit hazy on whether any post-processing is needed on the chunked (sharded) output.

A final thought: is KvikIO, which does direct Zarr-GPU IO, related to this?
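A hedged sketch of steps 1–3 above for a single chunk (the function name is illustrative, and it assumes each chunk fits in device memory):

```python
import jax.numpy as jnp
import numpy as np
import zarr


def jax_apply_to_chunk(func, source: zarr.Array, target: zarr.Array, key):
    np_chunk = source[key]              # 1. read the chunk from Zarr (NumPy, on the host)
    jax_chunk = jnp.asarray(np_chunk)   #    and transfer it to the default JAX device
    result = func(jax_chunk)            # 2. invoke the JAX function on the chunk
    target[key] = np.asarray(result)    # 3. copy back to the host and write to Zarr
```

For the FFT example, `func` could be something like `jnp.fft.fft` applied per chunk, with any cross-chunk post-processing handled separately.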
Reading the xarray blog post on this that @dcherian and @weiji14 wrote, it seems they used a Zarr store provided by kvikio. I expect Cubed could use this to load data from Zarr directly to the GPU in the form of a cupy array, which would be cool. (You could probably even use the xarray backend they wrote, alongside cubed-xarray, to achieve this.) I tried to find whether there was anything similar for JAX, but didn't see anything (only this: jax-ml/jax#17534). Writing from JAX to tensorstore was done for checkpointing language models (https://blog.research.google/2022/09/tensorstore-for-high-performance.html?m=1), but one would have thought that making tensorstore return JAX arrays directly would have been tried...
KvikIO loads data into cupy arrays.
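If a chunk is already on the GPU as a cupy array, handing it to JAX without a copy should be possible via DLPack. A hedged sketch (requires GPU builds of both libraries; older JAX versions may need a DLPack capsule from `gpu_chunk.toDlpack()` instead of the array itself):

```python
import cupy as cp
import jax.dlpack

# A cupy array resident on the GPU (in practice this could come from a
# kvikio-backed Zarr read rather than being created here).
gpu_chunk = cp.arange(8, dtype=cp.float32)

# Hand the buffer to JAX via the DLPack protocol, avoiding a host round trip.
jax_chunk = jax.dlpack.from_dlpack(gpu_chunk)
```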
Hey @tomwhite, I have a question for you: to run JAX arrays on accelerators (M1+ chips, GPUs, TPUs, etc.), someone needs to call `jax.jit`. Where is a good place to make this kind of call within Cubed? Is this something that should be handled by an Executor (this seems not so ideal)? (Here are some more in-depth docs on Jax's jit: https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation.)

A related concern that I haven't properly figured out yet: how should this intersect with devices and sharding?
Possibly as a part of DAG finalization, after (Cubed) optimization has been run. Although the function you want to jit will be the function in cubed/primitive/blockwise.py, lines 47 to 73 (at commit 59c593d).
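A hedged sketch of what that might look like as a finalization step (the hook name and signature are hypothetical; only `jax.jit` itself is real JAX API):

```python
import jax


def finalize_blockwise_func(block_func):
    """Hypothetical finalization hook: wrap the per-block function with jax.jit
    so it is compiled for the available backend the first time it is called."""
    return jax.jit(block_func)


# e.g. compiled = finalize_blockwise_func(lambda x: (x * 2.0).sum())
```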
What's the simplest possible example to start with?
Thanks for your suggestion, Tom. I've prototyped something here: alxmrs#1. For now, it looks like I need to work on landing the M1 PR before I can take this any further.
Nice!
I think compiling the (Cubed-optimized) blockwise functions using AOT compilation (as you mentioned in #490 (comment)), and then exporting them so they can run in other processes (https://jax.readthedocs.io/en/latest/export/export.html), may be the way to go. Perhaps this is worth trying on CPU first.
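A minimal sketch of that AOT-compile-and-export flow, based on the linked export docs (the block function here is just a stand-in for a Cubed-optimized blockwise function, and the exact `jax.export` API may vary with JAX version):

```python
import jax
import jax.numpy as jnp
from jax import export


def block_func(x):
    return jnp.fft.fft(x).real


x = jnp.zeros((1024,), dtype=jnp.float32)

# Ahead-of-time compile for the current backend and these concrete shapes/dtypes.
compiled = jax.jit(block_func).lower(x).compile()

# Export and serialize the computation so another process can run it.
exported = export.export(jax.jit(block_func))(x)
blob = exported.serialize()

# In the other process: deserialize and call.
restored = export.deserialize(blob)
result = restored.call(x)
```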
The original questions from the issue description (@alxmrs):

- Can the core array API ops of Cubed be implemented in JAX, such that everything easily compiles to accelerators?
- Could this solve the common pain point of running out of GPU memory?
- How would other constraints (e.g. GPU bandwidth limits) be handled?
- What is the ideal distributed runtime environment to make the most of this?
- Could spot GPU instances be used (serverless accelerators)?