-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Jax tests for the M1 mac. #508
base: main
Are you sure you want to change the base?
Conversation
Locally, I'm hitting a large number of errors. It looks like jax-metal is still highly experimental. 104 failed tests!
|
I've pushed some changes to cut the failed tests down in half locally. I'll definitely need design opinions on my review. The next thing I plan to tackle is randomness, which is a special case for Jax and all GPU acceleration. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this @alxmrs. I'm honoured that you are working on Cubed during your round the world trip!
I have a Mac mini M1 so I should be able to try it when you've got it working. Can you post your environment versions and any pointers to getting it set up please?
I’m on mobile at the moment, but in the meantime: I’m using Python 3.11 (compiled for ARM). I’ve followed these instructions to set up jax for the M1 (which specifically means installing https://developer.apple.com/metal/jax/ Right now, I’m trying out |
I'm hitting the same errors faced in #494, namely that
|
Created an affordance to apply a jit to blockwise functions after operator fusion. This will let the user better use accelerators. cubed-dev#508 needs to be merged first.
Created an affordance to apply a jit to blockwise functions after operator fusion. This will let the user better use accelerators. cubed-dev#508 needs to be merged first.
Created an affordance to apply a jit to blockwise functions after operator fusion. This will let the user better use accelerators. cubed-dev#508 needs to be merged first.
Created an affordance to apply a jit to blockwise functions after operator fusion. This will let the user better use accelerators. cubed-dev#508 needs to be merged first.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking good. I added a few comments about dtype handling.
The CI workflow is crashing with a seg fault - do you see the same on your machine?
dtype = nxp.arange(start, stop, step * num if num else step).dtype | ||
for k, dtype_ in default_dtypes(device=device).items(): | ||
if nxp.isdtype(dtype, k): | ||
dtype = dtype_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not clear to me what this is doing or why it is needed. If nxp is the jax namespace doesn't the call to arange already return the correct dtype (int32) - or does jax metal just return int64 or fail?
It might help to factor out a function to do this (given it is duplicated below too) with a name describing what it does, and perhaps a comment too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
or does jax metal just return int64 or fail?
Yes, it looks like jax metal returns 64 bit precisions and fails; this ensures the correct precision.
I've factored this out to a function, good shout.
cubed/tests/test_array_api.py
Outdated
x = np.arange(400, dtype=np.float32).reshape((20, 20)) | ||
a = xp.asarray(x, chunks=(5, 4), dtype=xp.float32) | ||
y = np.arange(200, dtype=np.float32).reshape((20, 10)) | ||
b = xp.asarray(y, chunks=(4, 5), dtype=xp.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto - we need to test that leaving out dtype works OK in these cases.
Thanks for the review! I’ll address these when I can.
No! Tests are running without any faults locally for me. I’ll do some digging to better understand the CI environment. |
@tomwhite It looks like to use the GPU on the M1 actions, we need to enable a premium action runner: https://github.blog/news-insights/product-news/introducing-the-new-apple-silicon-powered-m1-macos-larger-runner-for-github-actions/ I think this is the cause of the segfault (from what I can tell from related discussions): pytorch/pytorch#111449 (comment). How do you think we should proceed here? Should we attempt to have CI target GPUs, or should I configure Jax to run on the CPU? |
We should certainly have JAX running against the CPU in CI as it tests that Cubed works with the JAX array API. For testing on Mac M1 GPUs, I think we can add that later - particularly since it's a paid for option? It would be good to get the work that you have done here merged, so if you change the CI back then I'm happy to merge it. BTW I tried installing JAX metal on my Mac Mini M1 to run the tests, but I got an error ( |
There was some stuff at SciPy about quansight being able to give out free access to NVIDIA GPUs for scientific python projects to use in CI. But this seems like a much later concern only for once it works on CPU and seems useful. |
e78cd21
to
2f7c324
Compare
Looks like this is getting close @alxmrs. There are still a few places where you've changed the tests to have a lower precision dtype, where we should also test that it works when it's left at the default. Can we merge once they are resolved? |
That sounds good to me Tom. Since my development time is sporadic and limited, I'll try to make the Jax features I work on independent from each other from here on out. Today, I extracted #536; since this is a bigger / flakier PR, I'll probably need a few more sessions to get it to fully land. |
I'm extracting cubed-dev#508 into smaller bites.
I'm extracting #508 into smaller bites.
Hey Tom! I should have mentioned this earlier. Can you run the workflows again? I think this PR is ready. |
The Dask test failure is a flaky test (#549), but the Array API test failures look like they are real. |
Inlined in order to not violate API boundaries. Trying to put this in a good place ends up leading to a circular import issue.
@alxmrs thanks for pushing this forward! Do you think it might make things easier to reduce the scope by e.g. targeting JAX on CPU and focusing on the device/dtype inspection stuff. Or maybe there's another way of splitting things up? I'd be happy to merge smaller PRs! |
I totally agree — at this point, it’s best for me to split up this large
CL. Yeah, targeting JAX on CPU while respecting default dtypes is a good
idea for a first cut.
ETA: I’ll start on this after focusing on landing std/var.
…On Tue, Sep 24, 2024 at 2:04 AM Tom White ***@***.***> wrote:
@alxmrs <https://github.com/alxmrs> thanks for pushing this forward! Do
you think it might make things easier to reduce the scope by e.g. targeting
JAX on CPU and focusing on the device/dtype inspection stuff
<https://data-apis.org/array-api/latest/API_specification/inspection.html>
.
Or maybe there's another way of splitting things up? I'd be happy to merge
smaller PRs!
—
Reply to this email directly, view it on GitHub
<#508 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AARXAB5K6PAZKXRODOMPXQ3ZYBCY3AVCNFSM6AAAAABLGHA23GVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDGNRYHA3DMNJUGI>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
Great - thanks @alxmrs! |
I'm extracting cubed-dev#508 into smaller bites.
I'm beginning to explore #304 in greater depth. Since the only local GPU I have access to is an M1 chip (I have an M1 Macbook Air), I thought I would replicate this environment in CI.