Skip to content

Commit 76c00c7

Browse files
authored
is_safetensors_compatible fix (#9741)
update
1 parent 0d9d98f commit 76c00c7

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
118118
components.setdefault(component, [])
119119
components[component].append(component_filename)
120120

121+
# If there are no component folders check the main directory for safetensors files
122+
if not components:
123+
return any(".safetensors" in filename for filename in filenames)
124+
121125
# iterate over all files of a component
122126
# check if safetensor files exist for that component
123127
# if variant is provided check if the variant of the safetensors exists

tests/pipelines/test_pipeline_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,18 @@ def test_diffusers_is_compatible_only_variants(self):
197197
]
198198
self.assertTrue(is_safetensors_compatible(filenames))
199199

200+
def test_diffusers_is_compatible_no_components(self):
201+
filenames = [
202+
"diffusion_pytorch_model.bin",
203+
]
204+
self.assertFalse(is_safetensors_compatible(filenames))
205+
206+
def test_diffusers_is_compatible_no_components_only_variants(self):
207+
filenames = [
208+
"diffusion_pytorch_model.fp16.bin",
209+
]
210+
self.assertFalse(is_safetensors_compatible(filenames))
211+
200212

201213
class ProgressBarTests(unittest.TestCase):
202214
def get_dummy_components_image_generation(self):

0 commit comments

Comments
 (0)