Add JAX Frontend Support to Ivy Transpiler #28846
Labels:
- JAX Frontend: Developing the JAX Frontend, checklist triggered by commenting add_frontend_checklist
- ToDo: A ToDo list of tasks
- Transpiler: Anything related to transpiling
Description:
The current implementation of ivy.transpile supports "torch" as the sole source argument. This allows transpiling PyTorch functions or classes to target frameworks like TensorFlow, JAX, or NumPy. This task aims to extend the functionality by adding JAX as a valid source, enabling transpilation of JAX code to other frameworks via Ivy's intermediate representation.

For example, after completing this task, we should be able to transpile JAX code by calling ivy.transpile with source="jax".
Goals:
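The main objective is to implement the first two stages of the transpilation pipeline for JAX:
- JAX to JAX Frontend IR (source='jax', target='jax_frontend')
- JAX Frontend IR to Ivy (source='jax_frontend', target='ivy')
Once these stages are complete, the rest of the pipeline can be reused to target other frameworks like TensorFlow, PyTorch, or NumPy. The steps would look as follows: JAX source code, then JAX Frontend IR, then Ivy, then the target framework.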
This mirrors the existing pipeline for PyTorch:
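The stage sequence for both sources can be sketched in plain Python. This is only an illustration of the ordering, not the transpiler's actual code: the names NEXT_STAGE and plan_stages are hypothetical, the stage names 'jax', 'jax_frontend' and 'ivy' are the ones used in this issue, and 'torch_frontend' is assumed by analogy.

```python
# Hypothetical stage table: each representation maps to the next one on the
# way to Ivy. The "jax" entries use the stage names from this issue; the
# "torch" entries are an assumption made by analogy.
NEXT_STAGE = {
    "jax": "jax_frontend",
    "jax_frontend": "ivy",
    "torch": "torch_frontend",
    "torch_frontend": "ivy",
}


def plan_stages(source, target):
    """Return the ordered (source, target) hops from `source` to `target`."""
    if source not in NEXT_STAGE:
        raise ValueError(f"unsupported source: {source!r}")
    hops = []
    current = source
    # Walk source -> frontend IR -> ivy using the stage table.
    while current != "ivy":
        nxt = NEXT_STAGE[current]
        hops.append((current, nxt))
        current = nxt
    # The final hop (ivy -> target) is the part shared by every source.
    if target != "ivy":
        hops.append(("ivy", target))
    return hops


jax_plan = plan_stages("jax", "tensorflow")
torch_plan = plan_stages("torch", "tensorflow")
print(jax_plan)
print(torch_plan)
```

Only the first two hops differ between the two plans, which is why the new work is limited to the JAX-specific stages.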
Key Tasks:
Add Native Framework-Specific Implementations for Core Transformation Passes:
- native_jax_recursive_transformer.py for traversing and transforming JAX native source code.
- native_torch_recursive_transformer.py as a reference (example here)

Define the Transformation Pipeline for JAX to JAX Frontend IR:
- source_to_frontend_translator_config.py to handle the stage source='jax', target='jax_frontend' (example here).

Define the Transformation Pipeline for JAX Frontend IR to Ivy:
- frontend_to_ivy_translator_config.py to handle the stage source='jax_frontend', target='ivy' (example here).

Add Stateful Classes for Flax APIs:
- Stateful classes for the flax.nnx.Module API that inherit from ivy.Module.
- nn.Module (example here)

Understand and Leverage Reusability:
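To make the reuse idea concrete for the first key task, here is a minimal sketch of a shared recursive transformer built on Python's standard ast module. It is not the code in native_torch_recursive_transformer.py: every class and function name below is hypothetical, and the frontend namespaces are assumptions. The real passes do far more than this one rewrite.

```python
import ast


class NativeRecursiveTransformer(ast.NodeTransformer):
    """Hypothetical shared base: re-roots `<root>.<...>` references at a
    frontend namespace while recursively walking the source tree."""

    root = ""             # top-level module name in the native code
    frontend_prefix = ""  # dotted namespace to substitute (assumed)

    def visit_Name(self, node):
        # Only rewrite reads of the root module name, e.g. `jax` in
        # `jax.numpy.tanh`. Import statements and function parameters are
        # not ast.Name nodes, so they are left untouched by this sketch.
        if node.id == self.root and isinstance(node.ctx, ast.Load):
            replacement = ast.parse(self.frontend_prefix, mode="eval").body
            return ast.copy_location(replacement, node)
        return node


class NativeTorchRecursiveTransformer(NativeRecursiveTransformer):
    root = "torch"
    frontend_prefix = "ivy.functional.frontends.torch"  # assumed namespace


class NativeJaxRecursiveTransformer(NativeRecursiveTransformer):
    root = "jax"
    frontend_prefix = "ivy.functional.frontends.jax"  # assumed namespace


def transform_source(source, transformer_cls):
    """Parse `source`, apply the transformer, and return the new source."""
    tree = transformer_cls().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)  # requires Python 3.9+


jax_code = "def f(x):\n    return jax.numpy.tanh(x) * 2\n"
torch_code = "def g(x):\n    return torch.tanh(x) * 2\n"
print(transform_source(jax_code, NativeJaxRecursiveTransformer))
print(transform_source(torch_code, NativeTorchRecursiveTransformer))
```

In this sketch the JAX subclass only supplies two class attributes, so the traversal logic is written once and shared with the PyTorch variant.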
Testing:
Additional Notes:
nnx.Module / nnx.Variable.