M1 jax PR (for CI only) #113

Workflow file for this run

name: JAX tests
on:
  pull_request:
  schedule:
    # Every weekday at 03:53 UTC, see https://crontab.guru/
    - cron: "53 3 * * 1-5"
  workflow_dispatch:
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true
jobs:
  test:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        # How to set up Jax on an ARM Mac: https://developer.apple.com/metal/jax/
        os: ["ubuntu-latest", "macos-14"]
        python-version: ["3.11"]
        precision: ["64", "32"]
    steps:
      - name: Checkout source
        uses: actions/checkout@v3
        with:
          fetch-depth: 0
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
          architecture: x64
      - name: Setup Graphviz
        uses: ts-graphviz/setup-graphviz@v2
      - name: Install
        run: |
          python -m pip install --upgrade pip
          python -m pip install -e '.[test-jax]'
          # Verify jax
          python -c 'import jax; print(jax.numpy.arange(10))'
      - name: Run tests
        run: |
          # exclude a few tests that don't work on JAX
          pytest -k "not broadcast_trick and not object_dtype"
        env:
          CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
          JAX_ENABLE_X64: ${{ matrix.precision == "64" }}
          CUBED_DEFAULT_PRECISION_X32: ${{ matrix.precision == "32" }}
          ENABLE_PJRT_COMPATIBILITY: True

Check failure on line 53 in .github/workflows/jax-tests.yml

GitHub Actions / JAX tests

Invalid workflow file

The workflow is not valid.
.github/workflows/jax-tests.yml (Line: 53, Col: 27): Unexpected symbol: '"64"'. Located at position 21 within expression: matrix.precision == "64"
.github/workflows/jax-tests.yml (Line: 54, Col: 40): Unexpected symbol: '"32"'. Located at position 21 within expression: matrix.precision == "32"
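The parser rejects `"64"` and `"32"` because string literals inside GitHub Actions `${{ }}` expressions must be written with single quotes; double-quoted strings are not valid expression syntax. Below is a minimal sketch of a corrected `Run tests` step. It assumes the rest of the workflow stays as recorded for this run and is not necessarily how PR #113 was eventually fixed:

```yaml
- name: Run tests
  run: |
    # exclude a few tests that don't work on JAX
    pytest -k "not broadcast_trick and not object_dtype"
  env:
    CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
    # String literals in ${{ }} expressions must use single quotes
    JAX_ENABLE_X64: ${{ matrix.precision == '64' }}
    CUBED_DEFAULT_PRECISION_X32: ${{ matrix.precision == '32' }}
    ENABLE_PJRT_COMPATIBILITY: True
```

Each comparison evaluates to a boolean, and GitHub writes it into the step environment as the string `true` or `false`. JAX accepts either spelling for `JAX_ENABLE_X64`.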