add reshape to fix use_memory_efficient_attention in flax (#7918) #252
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Workflow identity as shown in the Actions UI.
name: Fast mps tests on main

# Trigger: pushes to the main branch that modify Python files under
# src/diffusers/ or tests/.
on:
  push:
    branches: [main]
    paths:
      - 'src/diffusers/**.py'
      - 'tests/**.py'
# Environment variables exported to every job and step in this workflow.
env:
  # Quoted: a bare yes/no is read as a boolean by YAML 1.1 parsers, but
  # the consumer of an env var needs the literal string.
  DIFFUSERS_IS_CI: "yes"
  HF_HOME: /mnt/cache
  OMP_NUM_THREADS: 8
  MKL_NUM_THREADS: 8
  HF_HUB_ENABLE_HF_TRANSFER: 1
  PYTEST_TIMEOUT: 600
  RUN_SLOW: "no"

# One concurrency group per workflow + head ref; when no head ref is
# available the unique run id is used instead, so such runs each get
# their own group. A newer run in the same group cancels the older one.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true
jobs:
  # Single job: run the fast PyTorch test suite on an Apple-silicon runner.
  run_fast_tests_apple_m1:
    name: Fast PyTorch MPS tests on MacOS
    runs-on: macos-13-xlarge

    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2

      # Remove untracked and ignored files left in the workspace.
      - name: Clean checkout
        shell: arch -arch arm64 bash {0}
        run: |
          git clean -fxd

      # Repository-local composite action.
      # NOTE(review): presumably this is what defines CONDA_RUN for the
      # steps below - confirm in .github/actions/setup-miniconda.
      - name: Setup miniconda
        uses: ./.github/actions/setup-miniconda
        with:
          # Quoted so the version is passed as a string, never a float
          # (an unquoted 3.10 would silently become 3.1).
          python-version: "3.9"

      - name: Install dependencies
        shell: arch -arch arm64 bash {0}
        run: |
          ${CONDA_RUN} python -m pip install --upgrade pip uv
          # Editable install of the checked-out project with its extras.
          # The explicit "." names the project directory, and the quotes
          # stop the shell from treating [...] as a glob character class.
          ${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
          ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
          ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
          ${CONDA_RUN} python -m uv pip install transformers --upgrade

      - name: Environment
        shell: arch -arch arm64 bash {0}
        run: |
          ${CONDA_RUN} python utils/print_env.py

      - name: Run fast PyTorch tests on M1 (MPS)
        shell: arch -arch arm64 bash {0}
        env:
          # Step-level override of the workflow-wide HF_HOME.
          HF_HOME: /System/Volumes/Data/mnt/cache
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/

      # Only on failure: print the short failure summary into the job log.
      - name: Failure short reports
        if: ${{ failure() }}
        run: cat reports/tests_torch_mps_failures_short.txt

      # Always upload the reports directory, whether the tests passed or not.
      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v4
        with:
          name: pr_torch_mps_test_reports
          path: reports