Update README.md (#1)
feng-intel authored Jun 26, 2023
1 parent 15b0133 commit 837cf64
The default build command is:
```bash
bazel build //xla/tools/pip_package:build_pip_package
```

The [OpenXLA](https://github.com/openxla/xla) Project brings together a community of developers and leading AI/ML teams to accelerate ML and address infrastructure fragmentation across ML frameworks and hardware.

Intel® Extension for OpenXLA includes a PJRT plugin implementation that seamlessly runs JAX models on Intel GPU. The PJRT API simplifies integration, allowing the Intel GPU plugin to be developed separately and quickly integrated into JAX. The same PJRT implementation also enables initial Intel GPU support for TensorFlow and PyTorch models with XLA acceleration. Refer to the [OpenXLA PJRT Plugin RFC](https://github.com/openxla/community/blob/main/rfcs/20230123-pjrt-plugin.md) for more details.

This guide gives an overview of the OpenXLA high-level integration structure and demonstrates how to build Intel® Extension for OpenXLA and run a JAX example with OpenXLA on Intel GPU. JAX is the first supported front end.

## 1. Overview
<p align="center">
<img src="openxla_for_intel_gpu.jpg" width="50%">
</p>

* [JAX](https://jax.readthedocs.io/en/latest/) provides a familiar NumPy-style API and composable function transformations for compilation, batching, automatic differentiation, and parallelization; the same code executes on multiple backends.
* TensorFlow and PyTorch support is on the way.
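As a concrete illustration of the JAX bullet above, here is a minimal sketch of composable transformations (`jit`, `grad`, `vmap`); it runs on any installed backend, CPU included, and the same code runs unchanged on Intel GPU once the plugin is installed:

```python
import jax
import jax.numpy as jnp

@jax.jit                      # compile with XLA
def sum_of_squares(x):
    return jnp.sum(x ** 2)

# Automatic differentiation: d/dx sum(x^2) = 2x
grad_fn = jax.grad(sum_of_squares)
print(grad_fn(jnp.array([1.0, 2.0, 3.0])))   # -> [2. 4. 6.]

# vmap batches the gradient computation over a leading axis
per_example = jax.vmap(grad_fn)(jnp.arange(6.0).reshape(2, 3))
print(per_example.shape)                     # -> (2, 3)
```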

## 2. Hardware and Software Requirements

### Hardware Requirements

Verified Hardware Platforms:
- Intel® Data Center GPU Max Series, Driver Version: [602](https://dgpu-docs.intel.com/releases/stable_602_20230323.html)
- Intel® Data Center GPU Flex Series 170, Driver Version: [602](https://dgpu-docs.intel.com/releases/stable_602_20230323.html)
- *Experimental:* Intel® Arc™ A-Series

### Software Requirements
- Intel® Data Center GPU Flex Series: Ubuntu 22.04, Red Hat 8.6 (64-bit)
- Intel® Data Center GPU Max Series: Ubuntu 22.04, Red Hat 8.6 (64-bit), SUSE Linux Enterprise Server (SLES) 15 SP3/SP4
- Intel® oneAPI Base Toolkit 2023.1
- Jax/Jaxlib 0.4.7
- Python 3.8-3.10
- pip 19.0 or later (requires manylinux2014 support)
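The interpreter constraints above can be checked programmatically. A small sketch; the helper below is illustrative, not part of this project:

```python
import sys

def version_in_range(version, lo, hi=None):
    """Compare a dotted version string against (major, minor) bounds."""
    v = tuple(int(p) for p in version.split(".")[:2])
    return v >= lo and (hi is None or v <= hi)

# Python must be 3.8-3.10, and pip must be 19.0 or later (manylinux2014).
print(version_in_range("%d.%d" % sys.version_info[:2], (3, 8), (3, 10)))
print(version_in_range("19.0", (19, 0)))           # -> True
print(version_in_range("3.11", (3, 8), (3, 10)))   # -> False
```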


### Install GPU Drivers

|Release|OS|Intel GPU|Install Intel GPU Driver|
|-|-|-|-|
|v1.2.0|Ubuntu 22.04, Red Hat 8.6|Intel® Data Center GPU Flex Series| Refer to the [Installation Guides](https://dgpu-docs.intel.com/installation-guides/index.html#intel-data-center-gpu-flex-series) for the latest driver installation. To install the verified driver [602](https://dgpu-docs.intel.com/releases/stable_602_20230323.html), append the specific version to each component, e.g. `sudo apt-get install intel-opencl-icd==23.05.25593.18-601~22.04`|
|v1.2.0|Ubuntu 22.04, Red Hat 8.6, SLES 15 SP3/SP4|Intel® Data Center GPU Max Series| Refer to the [Installation Guides](https://dgpu-docs.intel.com/installation-guides/index.html#intel-data-center-gpu-max-series) for the latest driver installation. To install the verified driver [602](https://dgpu-docs.intel.com/releases/stable_602_20230323.html), append the specific version to each component, e.g. `sudo apt-get install intel-opencl-icd==23.05.25593.18-601~22.04`|

## 3. Build and Install
```bash
# Source the oneAPI environment
$ source /opt/intel/oneapi/compiler/latest/env/vars.sh
$ source /opt/intel/oneapi/mkl/latest/env/vars.sh
$ source /opt/intel/oneapi/tbb/latest/env/vars.sh

$ git clone https://github.com/intel/intel-extension-for-openxla.git
$ cd intel-extension-for-openxla
$ pip install jax==0.4.7 jaxlib==0.4.7
$ ./configure        # Choose Yes for all.
$ bazel build //xla/tools/pip_package:build_pip_package
$ ./bazel-bin/xla/tools/pip_package/build_pip_package ./
$ pip install intel_extension_for_openxla-0.1.0-cp39-cp39-linux_x86_64.whl
```
This repo pulls the public OpenXLA code as a third-party dependency. For development, you often want to make changes to the XLA repository as well. You can override the pinned XLA repo with a local checkout:
```bash
bazel build --override_repository=xla=/path/to/xla //xla/tools/pip_package:build_pip_package
```

Then generate the wheel and install it.

**Notes**:
* This project releases only source code, no .whl or .so binaries, so this "build and install" flow is for testing purposes only.
* Besides the Python wheel, you can also build a shared library with `bazel build //xla:libitex_xla_extension.so` and run with the `PJRT_NAMES_AND_LIBRARY_PATHS` environment variable, the same way as described in the [OpenXLA Support on GPU guide](https://intel.github.io/intel-extension-for-tensorflow/latest/docs/guide/OpenXLA_Support_on_GPU.html).
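For the .so route, the environment variable must be set before JAX is imported. A hedged sketch, in which the plugin name `xpu` and the library path are illustrative assumptions (check the linked guide for the exact values on your system):

```python
import os

# Assumed plugin name and build-output path -- adjust to your environment.
plugin_name = "xpu"
lib_path = "/path/to/bazel-bin/xla/libitex_xla_extension.so"
os.environ["PJRT_NAMES_AND_LIBRARY_PATHS"] = f"{plugin_name}:{lib_path}"

# Import jax only after the variable is set, so the plugin is discovered:
# import jax
# print(jax.local_devices())
```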

## 4. Run JAX Example

* **Run the JAX Python code below.**
When running JAX code, make sure to `import intel_extension_for_openxla`; otherwise the "XPU" device cannot be detected. `jax.local_devices()` shows which devices are available.
```python
import jax
import jax.numpy as jnp
import intel_extension_for_openxla

print("jax.local_devices(): ", jax.local_devices())

@jax.jit
def lax_conv():
    key = jax.random.PRNGKey(0)
    lhs = jax.random.uniform(key, (2, 1, 9, 9), jnp.float32)
    rhs = jax.random.uniform(key, (1, 1, 4, 4), jnp.float32)
    side = jax.random.uniform(key, (1, 1, 1, 1), jnp.float32)
    out = jax.lax.conv_with_general_padding(lhs, rhs, (1, 1), ((0, 0), (0, 0)), (1, 1), (1, 1))
    out = jax.nn.relu(out)
    out = jnp.multiply(out, side)
    return out

print(lax_conv())
```
* **Reference result:**
```
jax.local_devices(): [IntelXpuDevice(id=0, process_index=0)]
[[[[2.0449753 2.093208 2.1844783 1.9769732 1.5857391 1.6942389]
[1.9218378 2.2862523 2.1549542 1.8367321 1.3978379 1.3860377]
[1.9456574 2.062028 2.0365305 1.901286 1.5255247 1.1421617]
[2.0621 2.2933435 2.1257985 2.1095486 1.5584903 1.1229166]
[1.7746235 2.2446113 1.7870374 1.8216239 1.557919 0.9832508]
[2.0887792 2.5433128 1.9749291 2.2580051 1.6096935 1.264905 ]]]
[[[2.175818 2.0094342 2.005763 1.6559253 1.3896458 1.4036925]
[2.1342552 1.8239582 1.6091168 1.434404 1.671778 1.7397764]
[1.930626 1.659667 1.6508744 1.3305787 1.4061482 2.0829628]
[2.130649 1.6637266 1.594426 1.2636002 1.7168686 1.8598001]
[1.9009514 1.7938274 1.4870623 1.6193901 1.5297288 2.0247464]
[2.0905268 1.7598859 1.9362347 1.9513799 1.9403584 2.1483061]]]]
```
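As a sanity check on the result above, the output spatial size follows from standard valid-convolution arithmetic (a generic sketch, not project code): a 9x9 input with a 4x4 kernel, stride 1, and no padding yields a 6x6 output.

```python
def conv_out_dim(n, k, stride=1, pad=0):
    """Output size of a convolution along one spatial dimension."""
    return (n + 2 * pad - k) // stride + 1

# 9x9 input, 4x4 kernel, stride 1, no padding -> 6, matching the 6x6 result.
print(conv_out_dim(9, 4))  # -> 6
```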

## 5. FAQ

* **Q**: The XPU device cannot be detected.
**A**: Don't forget to `import intel_extension_for_openxla` in your JAX code.
Print `jax.local_devices()` to check which devices are available.
`export OCL_ICD_ENABLE_TRACE=1` enables driver error logging.
`export ONEDNN_VERBOSE=2` prints detailed oneDNN execution info.
The code below enables more debug logging for a JAX application:
```python
import logging
logging.basicConfig(level=logging.DEBUG)
```
* **Q**: There is an error: `version 'GLIBCXX_3.4.30' not found`.
**A**: Upgrade GCC/libstdc++ to the latest version; for example, with conda:
```bash
conda install libstdcxx-ng==12.2.0 -c conda-forge
```