torch.compile compatibility with varlen APIs #11970
Conversation
As a reference, using the pytorch native attention (the default):

```python
import torch
import torch._inductor.utils
from diffusers import FluxPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda")
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16)
pipe.text_encoder.to("cuda")
pipe.text_encoder_2.to("cuda")
pipe.vae.to("cuda")

# pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=True, dynamic=False)
pipe.transformer.compile_repeated_blocks(fullgraph=True, dynamic=False)

prompt = "A cat holding a sign that says 'hello world'"
with torch._inductor.utils.fresh_inductor_cache():
    image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]

image.save("output.png")
```

cc @sayakpaul @anijain2305 is this expected?
I thought #11916 might be the culprit, so I rechecked the compilation tests. They are passing for me, so all looks good wrt tests. We will probably need some other way of testing why there are multiple recompilations (if it's unintended), and look into how to fix this.

compile test logs:

```
(nightly-venv) aryan@hf-dgx-01:~/work/diffusers$ RUN_COMPILE=1 RUN_SLOW=1 pytest -s tests/models/transformers/test_models_transformer_flux.py -k "torch_compile"
============================================= test session starts ==============================================
platform linux -- Python 3.10.14, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/aryan/work/diffusers
configfile: pyproject.toml
plugins: timeout-2.3.1, requests-mock-1.10.0, xdist-3.6.1, hydra-core-1.3.2, anyio-4.6.2.post1
collected 79 items / 77 deselected / 2 selected

tests/models/transformers/test_models_transformer_flux.py ..

=============================================== warnings summary ===============================================
../../../../raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108
../../../../raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108
../../../../raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108
../../../../raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108
  /raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "

tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests::test_torch_compile_recompilation_and_graph_break
  /raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:236: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================ 2 passed, 77 deselected, 5 warnings in 32.12s =================================
```
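For investigating the recompilations, one generic torch.compile debugging approach (a sketch, not something this PR adds) is to make recompiles loud so the failing guard is printed or turned into an error:

```python
import torch

# Print the reason (failing guard) for every recompilation; equivalently, run
# the repro with the environment variable TORCH_LOGS="recompiles".
torch._logging.set_logs(recompiles=True)

# Or turn any recompilation into a hard error, which is handy in tests.
torch._dynamo.config.error_on_recompile = True
```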
What's the pytorch version?
Oh, I thought I shared 😅 I've updated the description now (pytorch is 2.7.1 stable cu126)
Ran:

```python
import torch
import torch._inductor.utils
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

# pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=True, dynamic=False)
pipe.transformer.compile_repeated_blocks(fullgraph=True)

prompt = "A cat holding a sign that says 'hello world'"
with torch._inductor.utils.fresh_inductor_cache():
    image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]

image.save("output.png")
```

With torch … The tests invoke the … However, I would be interested to know more about why we may need Flash with varlen, because Flux doesn't use masks, and what advantages we get by using varlen.
varlen with compile failed because of the call to `.item()`. We need to make sure each attention provider works with each model where applicable. Fixing the behaviour with compile here will fix it for other models that do use masks (maybe HunyuanVideo IIRC); we're using Flux for the tests because it's the only model currently supported with the attention dispatcher.
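For context, a minimal standalone sketch of the `.item()` failure mode and the dynamo flag that works around it (illustrative only, not the diffusers attention code):

```python
import torch

# Without this flag, tensor.item() is a graph break, so fullgraph=True errors out.
torch._dynamo.config.capture_scalar_outputs = True

def uses_item(x: torch.Tensor, seqlens: torch.Tensor):
    # varlen attention needs max_seqlen as a Python int, which is where .item() comes in
    max_seqlen = seqlens.max().item()
    return x * max_seqlen

compiled = torch.compile(uses_item, fullgraph=True, dynamic=False)
out = compiled(torch.randn(2, 4), torch.tensor([3, 4]))
```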
Ah thanks for explaining! But I guess the issue you reported in #11970 (comment) is now green with torch nightly? Happy to investigate compilation issues further and/or design tests.
Looks into #11957.
Some of the changes related to `key_valid`/`value_valid` are removed. Removing padding tokens was probably incorrect (because of how `max_seq_lens` and `cu_seq_lens` interact) and not required (because we know that a batch of data will not have varying sequence lengths).
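For reference, a minimal sketch (assumed names and shapes, not the dispatcher code) of how `cu_seq_lens` and `max_seq_len` are typically derived from per-sample lengths for the varlen path; when every sample in the batch has the same length, the offsets are just multiples of that length and there are no padding tokens to strip:

```python
import torch
import torch.nn.functional as F

def cu_seqlens_from_lengths(seqlens: torch.Tensor) -> tuple[torch.Tensor, int]:
    # seqlens: (batch,) integer tensor of per-sample token counts.
    # flash-attn's varlen kernels expect cumulative offsets [0, l0, l0 + l1, ...]
    # as int32, plus the maximum length as a plain Python int.
    cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
    max_seqlen = int(seqlens.max())
    return cu_seqlens, max_seqlen

# Uniform lengths (the Flux case): nothing needs to be masked or removed.
cu, mx = cu_seqlens_from_lengths(torch.tensor([512, 512]))
# cu -> tensor([0, 512, 1024], dtype=torch.int32), mx -> 512
```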
I'm running into some other errors after using the flags suggested by @StrongerXi (the flags fix the initial error of `.item()` failing, but reveal other errors when running with `fullgraph=True`):

- If using `torch.compile` with `flash` backend (which is FA2 from source): ✅
- If using `compile_repeated_blocks` with `flash` backend (multiple recompiles on attention processor): ❌
  Stack trace
  This makes sense because each block is making a call to a "different" function (processor object). Not sure how to fix this.
- If using `flash_varlen` as backend: ✅
- If using `flash_varlen` as backend with `torch.compile`: ❌
  stack trace
  FA2 requires `cu_seq_lens` to be `torch.int32` but instead we end up having `torch.int64`. It looks like inductor is generating incorrect code by optimizing out the cast to `torch.int32` (see the dtype-check sketch after the note below).
  relevant part of inductor code
- If using `flash_varlen` as backend with `compile_repeated_blocks`: ❌
  Same as above case.
Note: wherever I mentioned using `torch.compile`, it means I tested with `fullgraph=True` and `dynamic=False`.
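As a side note, the dtype contract behind the last two failures can be exercised in isolation; the sketch below (a hypothetical call-site guard, not the actual dispatcher code) simply casts or verifies the offsets right before they would be handed to the varlen kernel:

```python
import torch

def ensure_int32_cu_seqlens(*cu_seqlens: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # Hypothetical guard: FA2's varlen kernels reject int64 offsets, so cast
    # (or assert on) the dtype immediately before calling into the kernel.
    return tuple(t if t.dtype == torch.int32 else t.to(torch.int32) for t in cu_seqlens)

cu_q = torch.tensor([0, 512, 1024])  # int64 by default
cu_k = torch.tensor([0, 512, 1024])
cu_q, cu_k = ensure_int32_cu_seqlens(cu_q, cu_k)
assert cu_q.dtype == cu_k.dtype == torch.int32
```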
Environment