From f848febacdc54c351ed0ed23fcc4c9349828021e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sun, 18 Aug 2024 08:47:26 +0530 Subject: [PATCH] feat: allow sharding for auraflow. (#8853) --- src/diffusers/models/transformers/auraflow_transformer_2d.py | 1 + tests/models/transformers/test_models_transformer_aura_flow.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 0b20c4818a02..ad64df0c0790 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -274,6 +274,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents. """ + _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"] _supports_gradient_checkpointing = True @register_to_config diff --git a/tests/models/transformers/test_models_transformer_aura_flow.py b/tests/models/transformers/test_models_transformer_aura_flow.py index 57fac4ba769c..51075b2b4cc1 100644 --- a/tests/models/transformers/test_models_transformer_aura_flow.py +++ b/tests/models/transformers/test_models_transformer_aura_flow.py @@ -29,6 +29,8 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase): model_class = AuraFlowTransformer2DModel main_input_name = "hidden_states" + # We override the items here because the transformer under consideration is small. + model_split_percents = [0.7, 0.6, 0.6] @property def dummy_input(self):