diff --git a/.github/workflows/jax-tests.yml b/.github/workflows/jax-tests.yml index 730084e1..3332ac3e 100644 --- a/.github/workflows/jax-tests.yml +++ b/.github/workflows/jax-tests.yml @@ -20,6 +20,7 @@ jobs: # How to set up Jax on an ARM Mac: https://developer.apple.com/metal/jax/ os: ["ubuntu-latest", "macos-14"] python-version: ["3.11"] + precision: ["64", "32"] steps: - name: Checkout source @@ -51,5 +52,6 @@ jobs: pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not random and not visualization and not plan_scaling and not optimization" env: CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy - JAX_ENABLE_X64: True + JAX_ENABLE_X64: ${{ matrix.precision == "64" }} + CUBED_DEFAULT_PRECISION_X32: ${{ matrix.precision == "32" }} ENABLE_PJRT_COMPATIBILITY: True