torch.compile compatibility with varlen APIs #11970

Draft
wants to merge 2 commits into main

Conversation

a-r-r-o-w
Member

@a-r-r-o-w a-r-r-o-w commented Jul 21, 2025

Looks into #11957.

Some of the changes related to key_valid/value_valid have been removed. Removing padding tokens was probably incorrect (because of how max_seq_lens and cu_seq_lens interact) and not required (because we know that a batch of data will not have varying sequence lengths).
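
For context, a rough sketch of the usual FA2 varlen bookkeeping (an illustration of the interaction, not necessarily what the diffusers backend does): the offsets are a prefix sum over per-sample lengths, so dropping padding tokens without also rebuilding cu_seq_lens and max_seq_len would leave the offsets pointing at the wrong rows.

import torch
import torch.nn.functional as F

def varlen_metadata(attention_mask: torch.Tensor):
    # per-sample valid lengths, e.g. [512, 512] for an unpadded Flux batch
    seqlens = attention_mask.sum(dim=1)
    # cumulative offsets into the flattened (total_tokens, heads, dim) tensors;
    # FA2 expects these in int32 (see the dtype error further down)
    cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
    # plain python int, which is where the .item() / compile friction comes from
    max_seqlen = int(seqlens.max().item())
    return cu_seqlens, max_seqlen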

I'm running into some other errors after using the flags suggested by @StrongerXi (the flags fix the initial error of .item() failing, but reveal other errors when running with fullgraph=True):

import torch
import torch._dynamo.config
from diffusers import FluxPipeline, FluxTransformer2DModel

torch._dynamo.config.capture_scalar_outputs = True
# Recompilations because of attention processor id mismatch
torch._dynamo.config.cache_size_limit = 100

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda")
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.text_encoder.to("cuda")
pipe.text_encoder_2.to("cuda")
pipe.vae.to("cuda")

# pipe.transformer.set_attention_backend("flash")
pipe.transformer.set_attention_backend("flash_varlen")

# pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=True, dynamic=False)
pipe.transformer.compile_repeated_blocks(fullgraph=True, dynamic=False)

prompt = "A cat holding a sign that says 'hello world'"
image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]

image.save("output.png")
  • If using torch.compile with flash backend (which is FA2 from source): ✅

  • If using compile_repeated_blocks with flash backend (multiple recompiles on attention processor): ❌

Stack trace
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles] Recompiling function forward in /home/aryan/work/diffusers/src/diffusers/models/transformers/transformer_flux.py:381
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     triggered by the following guard failure(s):
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/36: ___check_obj_id(self._modules['attn'].processor, 140462015041104)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/35: ___check_obj_id(self._modules['attn'].processor, 140462015040240)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/34: ___check_obj_id(self._modules['attn'].processor, 140462015039376)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/33: ___check_obj_id(self._modules['attn'].processor, 140462015038512)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/32: ___check_obj_id(self._modules['attn'].processor, 140462015037648)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/31: ___check_obj_id(self._modules['attn'].processor, 140466393496880)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/30: ___check_obj_id(self._modules['attn'].processor, 140466393496016)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/29: ___check_obj_id(self._modules['attn'].processor, 140466393495152)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/28: ___check_obj_id(self._modules['attn'].processor, 140466393494288)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/27: ___check_obj_id(self._modules['attn'].processor, 140466393493424)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/26: ___check_obj_id(self._modules['attn'].processor, 140466393492560)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/25: ___check_obj_id(self._modules['attn'].processor, 140466393491696)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/24: ___check_obj_id(self._modules['attn'].processor, 140466393490832)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/23: ___check_obj_id(self._modules['attn'].processor, 140466393489968)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/22: ___check_obj_id(self._modules['attn'].processor, 140466393489104)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/21: ___check_obj_id(self._modules['attn'].processor, 140466393488240)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/20: ___check_obj_id(self._modules['attn'].processor, 140466393487376)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/19: ___check_obj_id(self._modules['attn'].processor, 140466393486512)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/18: ___check_obj_id(self._modules['attn'].processor, 140466393485648)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/17: ___check_obj_id(self._modules['attn'].processor, 140466393484784)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/16: ___check_obj_id(self._modules['attn'].processor, 140466393483920)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/15: ___check_obj_id(self._modules['attn'].processor, 140466393483056)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/14: ___check_obj_id(self._modules['attn'].processor, 140466393482192)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/13: ___check_obj_id(self._modules['attn'].processor, 140466393481328)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/12: ___check_obj_id(self._modules['attn'].processor, 140466392857808)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/11: ___check_obj_id(self._modules['attn'].processor, 140466392856944)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/10: ___check_obj_id(self._modules['attn'].processor, 140466392856080)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/9: ___check_obj_id(self._modules['attn'].processor, 140466392855216)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/8: ___check_obj_id(self._modules['attn'].processor, 140466392854352)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/7: ___check_obj_id(self._modules['attn'].processor, 140466392853488)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/6: ___check_obj_id(self._modules['attn'].processor, 140466392852624)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/5: ___check_obj_id(self._modules['attn'].processor, 140466392851760)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/4: ___check_obj_id(self._modules['attn'].processor, 140466392850896)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/3: ___check_obj_id(self._modules['attn'].processor, 140466392850032)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/2: ___check_obj_id(self._modules['attn'].processor, 140466392849168)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/1: ___check_obj_id(self._modules['attn'].processor, 140466392848304)
V0721 23:40:59.397000 2159156 torch/_dynamo/guards.py:3006] [1/37] [__recompiles]     - 1/0: ___check_obj_id(self._modules['attn'].processor, 140466392847440)

This makes sense because each block is making a call to a "different" function (the processor object differs per block). I'm not sure how to fix this yet.
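
One possible direction (a hedged sketch continuing from the repro script above, untested, and assuming the dispatcher processors are stateless and interchangeable) is to reuse a single processor object so the ___check_obj_id guard sees the same id for every block:

# share one processor instance per block type so the guard matches across blocks
shared_processor = pipe.transformer.transformer_blocks[0].attn.processor
for block in pipe.transformer.transformer_blocks:
    block.attn.processor = shared_processor

shared_single_processor = pipe.transformer.single_transformer_blocks[0].attn.processor
for block in pipe.transformer.single_transformer_blocks:
    block.attn.processor = shared_single_processor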

  • If using flash_varlen as backend: ✅

  • If using flash_varlen as backend with torch.compile: ❌

Stack trace
Traceback (most recent call last):
  File "/home/aryan/work/diffusers/dump19.py", line 753, in <module>
    image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 918, in __call__
    noise_pred = self.transformer(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 734, in forward
    encoder_hidden_states, hidden_states = block(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
    return fn(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 441, in forward
    def forward(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1209, in forward
    return compiled_fn(full_args)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 328, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 689, in inner_fn
    outs = compiled_fn(args)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 495, in wrapper
    return compiled_fn(runtime_args)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_inductor/output_code.py", line 460, in __call__
    return self.current_callable(inputs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_inductor/utils.py", line 2404, in run
    return model(new_inputs)
  File "/tmp/torchinductor_aryan/yj/cyjfbeabktf6xelcncgrwo5ggvp4ip2rlggvtozudr43y43udth3.py", line 1297, in call
    buf33 = torch.ops.flash_attn._flash_attn_varlen_forward.default(buf28, buf29, buf30, buf31, buf32, u0, u1, 0.0, 0.08838834764831845, False)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_library/autograd.py", line 113, in autograd_impl
    result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_ops.py", line 761, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 335, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 367, in wrapped_fn
    return fn(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 170, in _flash_attn_varlen_forward
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
RuntimeError: cu_seqlens_q must have dtype int32

FA2 requires cu_seq_lens to be torch.int32, but we end up with torch.int64. It looks like Inductor is generating incorrect code by optimizing out the cast to torch.int32 (a possible workaround is sketched after the generated code below).

Relevant part of the generated Inductor code
   with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf28 = empty_strided_cuda((4608, 24, 128), (3072, 128, 1), torch.bfloat16)
        buf29 = empty_strided_cuda((4608, 24, 128), (3072, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [cumsum, cu_seqlens_q, cumsum_1, cu_seqlens_k, _flash_attn_varlen_forward], Original ATen: [aten.cumsum, aten.constant_pad_nd, flash_attn._flash_attn_varlen_forward]
        stream0 = get_raw_stream(0)
        triton_poi_fused__flash_attn_varlen_forward_constant_pad_nd_cumsum_7.run(buf16, arg23_1, arg24_1, buf21, buf28, buf29, 14155776, stream=stream0)
        del arg23_1
        del arg24_1
        del buf16
        buf30 = reinterpret_tensor(buf21, (4608, 24, 128), (3072, 128, 1), 0); del buf21  # reuse
        # Topologically Sorted Source Nodes: [cumsum, cu_seqlens_q, cumsum_1, cu_seqlens_k, _flash_attn_varlen_forward], Original ATen: [aten.cumsum, aten.constant_pad_nd, flash_attn._flash_attn_varlen_forward]
        stream0 = get_raw_stream(0)
        triton_poi_fused__flash_attn_varlen_forward_constant_pad_nd_cumsum_8.run(buf22, buf23, buf30, 14155776, stream=stream0)
        buf31 = empty_strided_cuda((2, ), (1, ), torch.int64)
        # Topologically Sorted Source Nodes: [cumsum, cu_seqlens_q, cumsum_1, cu_seqlens_k, _flash_attn_varlen_forward], Original ATen: [aten.cumsum, aten.constant_pad_nd, flash_attn._flash_attn_varlen_forward]
        stream0 = get_raw_stream(0)
        triton_poi_fused__flash_attn_varlen_forward_constant_pad_nd_cumsum_9.run(buf31, 2, stream=stream0)
        buf32 = empty_strided_cuda((2, ), (1, ), torch.int64)
        # Topologically Sorted Source Nodes: [cumsum, cu_seqlens_q, cumsum_1, cu_seqlens_k, _flash_attn_varlen_forward], Original ATen: [aten.cumsum, aten.constant_pad_nd, flash_attn._flash_attn_varlen_forward]
        stream0 = get_raw_stream(0)
        triton_poi_fused__flash_attn_varlen_forward_constant_pad_nd_cumsum_9.run(buf32, 2, stream=stream0)
        # Topologically Sorted Source Nodes: [cumsum, cu_seqlens_q, cumsum_1, cu_seqlens_k, _flash_attn_varlen_forward], Original ATen: [aten.cumsum, aten.constant_pad_nd, flash_attn._flash_attn_varlen_forward]
        buf33 = torch.ops.flash_attn._flash_attn_varlen_forward.default(buf28, buf29, buf30, buf31, buf32, u0, u1, 0.0, 0.08838834764831845, False)
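
A possible workaround, sketched under the assumption that the backend builds cu_seq_lens itself (in the generated code above, buf31/buf32 are allocated as torch.int64): produce the offsets in int32 from the start via the cumsum dtype argument, so there is no separate cast for the compiler to optimize away.

import torch
import torch.nn.functional as F

def build_cu_seqlens(seqlens_q: torch.Tensor, seqlens_k: torch.Tensor):
    # flash-attn's varlen kernels require int32 offsets; computing the cumsum
    # directly in int32 avoids relying on a trailing .to(torch.int32) cast
    cu_seqlens_q = F.pad(seqlens_q.cumsum(0, dtype=torch.int32), (1, 0))
    cu_seqlens_k = F.pad(seqlens_k.cumsum(0, dtype=torch.int32), (1, 0))
    return cu_seqlens_q, cu_seqlens_k
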
  • If using flash_varlen as backend with compile_repeated_blocks: ❌

Same as above case.

Note: wherever I mention using torch.compile, it means I tested with fullgraph=True and dynamic=False.

Environment
- 🤗 Diffusers version: 0.35.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.7.1+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): 0.8.5 (cpu)
- Jax version: 0.4.31
- JaxLib version: 0.4.31
- Huggingface_hub version: 0.32.5
- Transformers version: 4.52.3
- Accelerate version: 1.9.0.dev0
- PEFT version: 0.15.2
- Bitsandbytes version: 0.46.0
- Safetensors version: 0.5.3
- xFormers version: 0.0.30
- Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA DGX Display, 4096 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB

@a-r-r-o-w a-r-r-o-w requested a review from DN6 July 21, 2025 23:06
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w
Member Author

As a reference, using the PyTorch native attention backend (the default):

  • Full compile: ✅
  • Compiled Regions: ❌ (fails with recompilation errors if the cache size limit is not increased; see the logging snippet after the script below)
import torch
import torch._inductor.utils
from diffusers import FluxPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda")
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.text_encoder.to("cuda")
pipe.text_encoder_2.to("cuda")
pipe.vae.to("cuda")

# pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=True, dynamic=False)
pipe.transformer.compile_repeated_blocks(fullgraph=True, dynamic=False)

prompt = "A cat holding a sign that says 'hello world'"
with torch._inductor.utils.fresh_inductor_cache():
    image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]

image.save("output.png")
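
To surface the recompile reasons mentioned above, enabling the "recompiles" logging artifact (or setting TORCH_LOGS="recompiles" in the environment) prints the failing guards before the cache limit is hit; it is the same kind of [__recompiles] output shown earlier in the description.

import torch

# print guard failures that trigger each recompilation
torch._logging.set_logs(recompiles=True)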

cc @sayakpaul @anijain2305 is this expected?

@a-r-r-o-w
Member Author

a-r-r-o-w commented Jul 21, 2025

I thought #11916 might be the culprit, so I re-ran the compilation tests. They pass for me, so all looks good with respect to the tests. We will probably need some other way of investigating why there are multiple recompilations (if they are unintended), and then look into how to fix this.

compile test logs
(nightly-venv) aryan@hf-dgx-01:~/work/diffusers$ RUN_COMPILE=1 RUN_SLOW=1 pytest -s tests/models/transformers/test_models_transformer_flux.py -k "torch_compile"
============================================= test session starts ==============================================
platform linux -- Python 3.10.14, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/aryan/work/diffusers
configfile: pyproject.toml
plugins: timeout-2.3.1, requests-mock-1.10.0, xdist-3.6.1, hydra-core-1.3.2, anyio-4.6.2.post1
collected 79 items / 77 deselected / 2 selected                                                                

tests/models/transformers/test_models_transformer_flux.py ..

=============================================== warnings summary ===============================================
../../../../raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108
../../../../raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108
../../../../raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108
../../../../raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108
  /raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "

tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests::test_torch_compile_recompilation_and_graph_break
  /raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:236: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================ 2 passed, 77 deselected, 5 warnings in 32.12s =================================

@StrongerXi

What's the pytorch version?

@a-r-r-o-w
Member Author

Oh, I thought I had shared it 😅 I've updated the description now (PyTorch is 2.7.1 stable, cu126).

@sayakpaul
Member

sayakpaul commented Jul 22, 2025

Ran

import torch
import torch._inductor.utils
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

# pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=True, dynamic=False)
pipe.transformer.compile_repeated_blocks(fullgraph=True)

prompt = "A cat holding a sign that says 'hello world'"
with torch._inductor.utils.fresh_inductor_cache():
    image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]

image.save("output.png")

with torch 2.8.0.dev20250604+cu126, and it worked.

The tests invoke the model twice to minimally mimic the iterative denoising process. Maybe we could decorate the tests so that it becomes clear that they should be run with a PyTorch nightly version (which is probably what @StrongerXi also hinted at).
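
A possible shape for such a decorator (the helper name and version cutoff below are illustrative, not existing diffusers test utilities):

import pytest
import torch
from packaging import version

# skip compile + varlen tests on stable PyTorch releases
requires_torch_nightly = pytest.mark.skipif(
    version.parse(torch.__version__).release < (2, 8),
    reason="varlen + torch.compile needs a PyTorch 2.8 nightly",
)

@requires_torch_nightly
def test_flash_varlen_with_compile():
    ...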

However, I would be interested to know more about why we need Flash with varlen, given that Flux doesn't use masks, and what advantages we get from using varlen.

@a-r-r-o-w
Member Author

varlen with compile failed because of the call to .item(). We need to make sure each attention provider works with each model where applicable. Fixing the behaviour with compile here will fix it for other models that do use masks (maybe HunyuanVideo, IIRC); we're using Flux for the tests because it's the only model currently supported by the attention dispatcher.
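
For reference, a minimal illustration of that failure mode (assuming the varlen path reads the max sequence length out of a tensor, as in the metadata sketch near the top):

import torch
import torch._dynamo.config

def uses_item(mask: torch.Tensor) -> torch.Tensor:
    max_seqlen = mask.sum(dim=1).max().item()  # data-dependent scalar read
    return mask.float() * max_seqlen

compiled = torch.compile(uses_item, fullgraph=True)
mask = torch.ones(2, 16, dtype=torch.bool)

# Without the flag, fullgraph=True errors on the .item() call; with it, Dynamo
# traces the scalar as an unbacked symbol (the u0/u1 visible in the Inductor
# call above), which is where the later dtype issue surfaces.
torch._dynamo.config.capture_scalar_outputs = True
compiled(mask)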

@sayakpaul
Member

Ah, thanks for explaining! But I guess the issue you reported in #11970 (comment) is now green with a torch nightly? Happy to investigate compilation issues further and/or design tests.
