Optional autodiff support? #518
Cubed has a simple model of upper bounds for memory usage, derived from knowledge about the different operations in the array API. So if there's a way of modelling the memory usage of gradient operations, then this should be possible.
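For context, a minimal sketch of what that model looks like in use (assuming Cubed's `Spec` and `array_api` namespaces; the sizes here are arbitrary):

```python
import cubed
import cubed.array_api as xp

# Upper bound on memory available to each task; Cubed checks every
# operation's projected memory usage against this bound when the
# plan is built, before anything executes.
spec = cubed.Spec(allowed_mem="2GB")

a = xp.ones((10000, 10000), chunks=(1000, 1000), spec=spec)
b = xp.ones((10000, 10000), chunks=(1000, 1000), spec=spec)
c = xp.add(a, b)  # per-task memory for this op is known at graph-build time
```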
> So if there's a way of modelling the memory usage of gradient operations, then this should be possible
I’ve been exploring the JAX docs for an answer and I have two ideas so far.

Option one: we can take advantage of the existing jit (or grad?) mechanics to extract array shape information ahead of time (via tracers).
https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables

I suspect that grad will produce a trace carrying the array shape and dtype information, which should be enough to derive memory bounds for Cubed. Something like the sketch below.
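A minimal sketch of that idea; the loss function is a stand-in, while `jax.eval_shape` and `jax.make_jaxpr` are JAX's real entry points for abstract tracing:

```python
import jax
import jax.numpy as jnp

# Stand-in loss function; any function traceable by JAX works here.
def loss(x):
    return jnp.sum(jnp.tanh(x) ** 2)

# eval_shape traces abstractly: no array data is allocated, but the
# output shape/dtype of the gradient is fully known ahead of time.
abstract_x = jax.ShapeDtypeStruct((1000, 1000), jnp.float32)
print(jax.eval_shape(jax.grad(loss), abstract_x))
# -> ShapeDtypeStruct(shape=(1000, 1000), dtype=float32)

# make_jaxpr exposes the shape of every intermediate in the gradient
# computation, which is what a Cubed memory bound would need.
print(jax.make_jaxpr(jax.grad(loss))(jnp.ones((4, 4))))
```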
Option two: we could create a memory profiling tracer.
https://jax.readthedocs.io/en/latest/autodidax.html

Tracers are JAX-style visitors. I think if we created a generic memory tracer, we could probably use it on both grad and non-grad JAX programs.
https://github.com/google/jax/blob/694c14bbe6e365f543c7dc67114c8c5e67b5c2df/jax/_src/core.py#L512

Maybe it would be implemented as an AbstractTracer, though grad is concrete. Hmm…
https://jax.readthedocs.io/en/latest/faq.html#different-kinds-of-jax-values
https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array
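Short of writing a full Tracer subclass in the autodidax style, a rough version of the same idea is to walk the jaxpr that tracing produces and bound each equation's memory from its input/output avals. A sketch under that assumption (it does not descend into nested sub-jaxprs such as pjit or custom_jvp calls):

```python
import jax
import jax.numpy as jnp
import numpy as np

def max_eqn_bytes(jaxpr):
    """Crude per-operation memory upper bound: the largest combined
    size of input and output buffers across all equations."""
    def nbytes(v):
        aval = v.aval
        return int(np.prod(aval.shape, dtype=np.int64)) * aval.dtype.itemsize

    worst = 0
    for eqn in jaxpr.eqns:  # NOTE: nested sub-jaxprs are not descended into
        total = sum(nbytes(v) for v in eqn.invars)
        total += sum(nbytes(v) for v in eqn.outvars)
        worst = max(worst, total)
    return worst

def loss(x):
    return jnp.sum(jnp.tanh(x) ** 2)

closed = jax.make_jaxpr(jax.grad(loss))(jnp.ones((1000, 1000)))
print(max_eqn_bytes(closed.jaxpr), "bytes for the largest single op")
```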
If you want to do this on top of JAX, I think the easiest approach is probably to build a custom interpreter that implements the array API on top of JAXprs: JAX primitives are typically thin wrappers around a minimal set of XLA operations, so hopefully this would be relatively straightforward. JAXprs are quite explicit about array shapes, so memory usage should be fairly transparent.
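To make the suggestion concrete, here is a minimal jaxpr interpreter skeleton following the pattern in JAX's "Writing custom Jaxpr interpreters" tutorial; the line marked as a placeholder is where a Cubed implementation of each primitive would be dispatched instead (that dispatch is hypothetical):

```python
import jax
import jax.numpy as jnp
from jax import core

def interpret(closed_jaxpr, *args):
    """Walk a jaxpr equation by equation, keeping an environment that
    maps variables to values. Shapes are explicit on every variable's
    aval, so memory accounting could happen alongside evaluation."""
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(var):
        return var.val if isinstance(var, core.Literal) else env[var]

    for var, const in zip(jaxpr.constvars, closed_jaxpr.consts):
        env[var] = const
    for var, arg in zip(jaxpr.invars, args):
        env[var] = arg

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        # Placeholder: re-bind the original primitive. A Cubed-backed
        # interpreter would call a chunked array-API implementation of
        # eqn.primitive here instead.
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for v, val in zip(eqn.outvars, outvals):
            env[v] = val

    return [read(v) for v in jaxpr.outvars]

f = lambda x: jnp.sum(x ** 2)
closed = jax.make_jaxpr(jax.grad(f))(jnp.ones(3))
print(interpret(closed, jnp.ones(3)))  # matches jax.grad(f)(jnp.ones(3))
```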
Thanks @shoyer! That's very useful.
It would be awesome if the backing array implementation supported automatic differentiation, so that we could access some `grad` method from Cubed. It looks like a bunch of stakeholder libraries have this functionality:
https://data-apis.org/array-api/latest/purpose_and_scope.html#stakeholders

Though differentiable programming may be out of scope for Cubed. @TomNicholas @tomwhite @rbavery any thoughts here?
I have a pipe dream of turning Cubed into an ML framework, and I think this would play an important part.
I haven’t thought through all the implications, but there's a potential sharp edge that @shoyer once pointed out to me: there will probably be significant memory differences between an op graph and its gradient. Can Cubed’s spec model be extended here?