what about JAX? #249
Replies: 4 comments
-
It has been discussed here: #254. I think it would be an amazing addition, as JAX offers some extremely useful functionality not really available in PyTorch. JAX does offer a callback mechanism, as well as a lower-level custom_call function, so I definitely think it is possible. As I mentioned, somebody has already written a wrapper around the Triton language (https://github.com/jax-ml/jax-triton) for a similar reason, but if you look at the code for that Triton integration, it is clearly not as straightforward as building a bridge between, say, C++ code and PyTorch / TensorFlow.
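To illustrate the callback route mentioned above, here is a minimal sketch using jax.pure_callback. The host-side external_gaussian_conv is a hypothetical stand-in for a KeOps-style reduction, not actual KeOps code; any NumPy-compatible routine would work the same way:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical host-side routine standing in for a KeOps-style reduction;
# it only ever sees NumPy arrays, never JAX tracers.
def external_gaussian_conv(x, y, b):
    d2 = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2) @ b

def gaussian_conv(x, y, b):
    out_shape = jax.ShapeDtypeStruct((x.shape[0], b.shape[1]), x.dtype)
    # pure_callback escapes the trace and runs the host function, so the
    # call stays jit-compatible (though not differentiable as written).
    return jax.pure_callback(external_gaussian_conv, out_shape, x, y, b)

x = jnp.ones((5, 3)); y = jnp.ones((4, 3)); b = jnp.ones((4, 2))
print(jax.jit(gaussian_conv)(x, y, b).shape)  # (5, 2)
```

The catch is exactly what the callback docs warn about: the data round-trips through the host, so you lose fusion and pay transfer costs, which is why the custom_call route is the more serious option.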
-
A toy package containing boilerplate for writing custom CUDA kernels for JAX.
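For anyone curious, the core of that boilerplate is XLA's custom-call registration. A rough sketch of the pattern, where my_cuda_ext and its registrations() helper are hypothetical placeholders for a compiled pybind11 extension (so this won't run without one):

```python
from jax.lib import xla_client

import my_cuda_ext  # hypothetical compiled pybind11 extension

# Expose each compiled CUDA kernel to XLA under a string name; JAX can
# then lower a custom primitive to a custom_call targeting that name.
for name, capsule in my_cuda_ext.registrations().items():
    xla_client.register_custom_call_target(name, capsule, platform="gpu")
```

Binding that target to a JAX primitive (abstract eval, lowering rule, gradients) is the bulk of the remaining boilerplate such a package wraps up.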
-
FYI, this might provide the basis of a solution for supporting KeOps with JAX.
-
An alternative torch2jax that, unlike rdyro's, doesn't build a custom operator; instead, it uses abstract interpretation (a.k.a. tracing) to move JAX values through PyTorch code. As a result, you get a JAX-native computation graph that follows your PyTorch code exactly, down to the last epsilon.
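In practice the usage looks roughly like this, a minimal sketch assuming the t2j entry point of that tracing-based torch2jax package:

```python
import jax
import jax.numpy as jnp
import torch
from torch2jax import t2j  # the tracing-based torch2jax

def torch_fn(x):
    # Plain PyTorch code; tracing moves JAX values through it.
    return torch.nn.functional.softmax(x, dim=-1) * x.sum()

jax_fn = t2j(torch_fn)                 # a JAX-native function
out = jax.jit(jax_fn)(jnp.arange(4.0))
grads = jax.grad(lambda x: jax_fn(x).sum())(jnp.arange(4.0))
```

Since the result is an ordinary JAX computation graph, jit, grad, and vmap compose with it, which is the main advantage over the custom-operator approach.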
-
Hello,
Your lib sounds amazing. Do you have any plans to build a JAX-compatible bridge? Some friends are building a JAX pipeline, and it would be great for them to benefit from your impressive work!
Thanks.