M1 jax PR (for CI only) #113
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
---
name: JAX tests

on:
  pull_request:
  schedule:
    # Scheduled run at 03:53 UTC, Monday through Friday (cron syntax: https://crontab.guru/).
    - cron: '53 3 * * 1-5'
  workflow_dispatch:

# Runs are grouped per workflow and ref; starting a new run for the same
# group cancels the one still in progress.
concurrency:
  group: "${{ github.workflow }}-${{ github.ref }}"
  cancel-in-progress: true
jobs:
  test:
    runs-on: ${{ matrix.os }}
    strategy:
      # Let every matrix combination finish even if another one fails.
      fail-fast: false
      matrix:
        # How to set up JAX on an ARM Mac: https://developer.apple.com/metal/jax/
        # macos-14 runners are Apple Silicon (arm64).
        os: ["ubuntu-latest", "macos-14"]
        python-version: ["3.11"]
        # Precision is a quoted string so the comparisons below are string comparisons.
        precision: ["64", "32"]
    steps:
      - name: Checkout source
        uses: actions/checkout@v4
        with:
          # Fetch the full history, not just the tip commit.
          fetch-depth: 0
      - name: Set up Python
        # No `architecture:` pin: forcing x64 would install an Intel (Rosetta)
        # interpreter on the arm64 macos-14 runner. setup-python v5 defaults to
        # the runner's native architecture.
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Setup Graphviz
        uses: ts-graphviz/setup-graphviz@v2
      - name: Install
        run: |
          python -m pip install --upgrade pip
          python -m pip install -e '.[test-jax]'
          # Verify jax
          python -c 'import jax; print(jax.numpy.arange(10))'
      - name: Run tests
        run: |
          # exclude a few tests that don't work on JAX
          pytest -k "not broadcast_trick and not object_dtype"
        env:
          CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
          # String literals inside ${{ }} expressions must use single quotes;
          # double quotes make the whole workflow file invalid. Each comparison
          # yields a boolean, which is exported as the string "true" or "false".
          JAX_ENABLE_X64: ${{ matrix.precision == '64' }}
          CUBED_DEFAULT_PRECISION_X32: ${{ matrix.precision == '32' }}
          # Quoted so the value is explicitly the string "true", not a YAML boolean.
          # NOTE(review): presumably required by the Apple jax-metal plugin on
          # macOS (see the setup link above) - confirm.
          ENABLE_PJRT_COMPATIBILITY: "true"