diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index a0af28803d79..af323164f562 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -89,49 +89,44 @@
     ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
 
 
-def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
+def is_safetensors_compatible(filenames, passed_components=None) -> bool:
     """
     Checking for safetensors compatibility:
-    - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
-      files to know which safetensors files are needed.
-    - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
+    - The model is safetensors compatible only if there is a safetensors file for each model component present in
+      filenames.
 
     Converting default pytorch serialized filenames to safetensors serialized filenames:
     - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
     - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
       extension is replaced with ".safetensors"
     """
-    pt_filenames = []
-
-    sf_filenames = set()
-
     passed_components = passed_components or []
 
+    # extract all components of the pipeline and their associated files
+    components = {}
     for filename in filenames:
-        _, extension = os.path.splitext(filename)
+        if not len(filename.split("/")) == 2:
+            continue
 
-        if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
+        component, component_filename = filename.split("/")
+        if component in passed_components:
             continue
 
-        if extension == ".bin":
-            pt_filenames.append(os.path.normpath(filename))
-        elif extension == ".safetensors":
-            sf_filenames.add(os.path.normpath(filename))
+        components.setdefault(component, [])
+        components[component].append(component_filename)
 
-    for filename in pt_filenames:
-        # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
-        path, filename = os.path.split(filename)
-        filename, extension = os.path.splitext(filename)
+    # iterate over all files of a component
+    # and check that at least one safetensors file exists for that component
+    # (variant and sharded safetensors checkpoints both satisfy this check)
+    for component, component_filenames in components.items():
+        matches = []
+        for component_filename in component_filenames:
+            filename, extension = os.path.splitext(component_filename)
 
-        if filename.startswith("pytorch_model"):
-            filename = filename.replace("pytorch_model", "model")
-        else:
-            filename = filename
+            match_exists = extension == ".safetensors"
+            matches.append(match_exists)
 
-        expected_sf_filename = os.path.normpath(os.path.join(path, filename))
-        expected_sf_filename = f"{expected_sf_filename}.safetensors"
-        if expected_sf_filename not in sf_filenames:
-            logger.warning(f"{expected_sf_filename} not found")
+        if not any(matches):
             return False
 
     return True
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 2cc9defc3ffa..f2882c5b1d02 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -1416,18 +1416,14 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             if (
                 use_safetensors
                 and not allow_pickle
-                and not is_safetensors_compatible(
-                    model_filenames, variant=variant, passed_components=passed_components
-                )
+                and not is_safetensors_compatible(model_filenames, passed_components=passed_components)
             ):
                 raise EnvironmentError(
                     f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
                 )
             if from_flax:
                 ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
-            elif use_safetensors and is_safetensors_compatible(
-                model_filenames, variant=variant, passed_components=passed_components
-            ):
+            elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components):
                 ignore_patterns = ["*.bin", "*.msgpack"]
 
                 use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py
index 51d987d8bb11..0e3f2e8c2e27 100644
--- a/tests/pipelines/test_pipeline_utils.py
+++ b/tests/pipelines/test_pipeline_utils.py
@@ -68,25 +68,21 @@ def test_all_is_compatible_variant(self):
             "unet/diffusion_pytorch_model.fp16.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
+        self.assertTrue(is_safetensors_compatible(filenames))
 
     def test_diffusers_model_is_compatible_variant(self):
         filenames = [
             "unet/diffusion_pytorch_model.fp16.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
+        self.assertTrue(is_safetensors_compatible(filenames))
 
-    def test_diffusers_model_is_compatible_variant_partial(self):
-        # pass variant but use the non-variant filenames
+    def test_diffusers_model_is_compatible_variant_mixed(self):
         filenames = [
             "unet/diffusion_pytorch_model.bin",
-            "unet/diffusion_pytorch_model.safetensors",
+            "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
+        self.assertTrue(is_safetensors_compatible(filenames))
 
     def test_diffusers_model_is_not_compatible_variant(self):
         filenames = [
@@ -99,25 +95,14 @@ def test_diffusers_model_is_not_compatible_variant(self):
             "unet/diffusion_pytorch_model.fp16.bin",
             # Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
         ]
-        variant = "fp16"
-        self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
+        self.assertFalse(is_safetensors_compatible(filenames))
 
     def test_transformer_model_is_compatible_variant(self):
         filenames = [
             "text_encoder/pytorch_model.fp16.bin",
             "text_encoder/model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
-
-    def test_transformer_model_is_compatible_variant_partial(self):
-        # pass variant but use the non-variant filenames
-        filenames = [
-            "text_encoder/pytorch_model.bin",
-            "text_encoder/model.safetensors",
-        ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
+        self.assertTrue(is_safetensors_compatible(filenames))
 
     def test_transformer_model_is_not_compatible_variant(self):
         filenames = [
@@ -126,9 +111,45 @@ def test_transformer_model_is_not_compatible_variant(self):
             "vae/diffusion_pytorch_model.fp16.bin",
             "vae/diffusion_pytorch_model.fp16.safetensors",
             "text_encoder/pytorch_model.fp16.bin",
-            # 'text_encoder/model.fp16.safetensors',
             "unet/diffusion_pytorch_model.fp16.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
+        self.assertFalse(is_safetensors_compatible(filenames))
+
+    def test_transformers_is_compatible_sharded(self):
+        filenames = [
+            "text_encoder/pytorch_model.bin",
+            "text_encoder/model-00001-of-00002.safetensors",
+            "text_encoder/model-00002-of-00002.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
+
+    def test_transformers_is_compatible_variant_sharded(self):
+        filenames = [
+            "text_encoder/pytorch_model.bin",
+            "text_encoder/model.fp16-00001-of-00002.safetensors",
+            "text_encoder/model.fp16-00002-of-00002.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
+
+    def test_diffusers_is_compatible_sharded(self):
+        filenames = [
+            "unet/diffusion_pytorch_model.bin",
+            "unet/diffusion_pytorch_model-00001-of-00002.safetensors",
+            "unet/diffusion_pytorch_model-00002-of-00002.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
+
+    def test_diffusers_is_compatible_variant_sharded(self):
+        filenames = [
+            "unet/diffusion_pytorch_model.bin",
+            "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
+            "unet/diffusion_pytorch_model.fp16-00002-of-00002.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
+
+    def test_diffusers_is_compatible_only_variants(self):
+        filenames = [
+            "unet/diffusion_pytorch_model.fp16.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 1d37ae1dc2ca..c73a12a4cbf8 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -551,37 +551,94 @@ def test_download_variant_partly(self):
         assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
         assert not any(f.endswith(other_format) for f in files)
 
-    def test_download_broken_variant(self):
-        for use_safetensors in [False, True]:
-            # text encoder is missing no variant and "no_ema" variant weights, so the following can't work
-            for variant in [None, "no_ema"]:
-                with self.assertRaises(OSError) as error_context:
-                    with tempfile.TemporaryDirectory() as tmpdirname:
-                        tmpdirname = StableDiffusionPipeline.from_pretrained(
-                            "hf-internal-testing/stable-diffusion-broken-variants",
-                            cache_dir=tmpdirname,
-                            variant=variant,
-                            use_safetensors=use_safetensors,
-                        )
-
-                assert "Error no file name" in str(error_context.exception)
-
-            # text encoder has fp16 variants so we can load it
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                tmpdirname = StableDiffusionPipeline.download(
+    def test_download_safetensors_only_variant_exists_for_model(self):
+        variant = None
+        use_safetensors = True
+
+        # text encoder is missing non-variant weights, so the following can't work
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with self.assertRaises(OSError) as error_context:
+                tmpdirname = StableDiffusionPipeline.from_pretrained(
                     "hf-internal-testing/stable-diffusion-broken-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
                     use_safetensors=use_safetensors,
+                )
+            assert "Error no file name" in str(error_context.exception)
+
+        # text encoder has fp16 variants so we can load it
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tmpdirname = StableDiffusionPipeline.download(
+                "hf-internal-testing/stable-diffusion-broken-variants",
+                use_safetensors=use_safetensors,
+                cache_dir=tmpdirname,
+                variant="fp16",
+            )
+            all_root_files = [t[-1] for t in os.walk(tmpdirname)]
+            files = [item for sublist in all_root_files for item in sublist]
+            # None of the downloaded files should be a non-variant file even if we have some here:
+            # https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
+            assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
+
+    def test_download_bin_only_variant_exists_for_model(self):
+        variant = None
+        use_safetensors = False
+
+        # text encoder is missing non-variant weights, so the following can't work
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with self.assertRaises(OSError) as error_context:
+                tmpdirname = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/stable-diffusion-broken-variants",
                     cache_dir=tmpdirname,
-                    variant="fp16",
+                    variant=variant,
+                    use_safetensors=use_safetensors,
                 )
+            assert "Error no file name" in str(error_context.exception)
 
-                all_root_files = [t[-1] for t in os.walk(tmpdirname)]
-                files = [item for sublist in all_root_files for item in sublist]
+        # text encoder has fp16 variants so we can load it
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tmpdirname = StableDiffusionPipeline.download(
+                "hf-internal-testing/stable-diffusion-broken-variants",
+                use_safetensors=use_safetensors,
+                cache_dir=tmpdirname,
+                variant="fp16",
+            )
+            all_root_files = [t[-1] for t in os.walk(tmpdirname)]
+            files = [item for sublist in all_root_files for item in sublist]
+            # None of the downloaded files should be a non-variant file even if we have some here:
+            # https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
+            assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
 
-                # None of the downloaded files should be a non-variant file even if we have some here:
-                # https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
-                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
-                # only unet has "no_ema" variant
+    def test_download_safetensors_variant_does_not_exist_for_model(self):
+        variant = "no_ema"
+        use_safetensors = True
+
+        # text encoder is missing no_ema variant weights, so the following can't work
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with self.assertRaises(OSError) as error_context:
+                tmpdirname = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/stable-diffusion-broken-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
+                    use_safetensors=use_safetensors,
+                )
+
+            assert "Error no file name" in str(error_context.exception)
+
+    def test_download_bin_variant_does_not_exist_for_model(self):
+        variant = "no_ema"
+        use_safetensors = False
+
+        # text encoder is missing no_ema variant weights, so the following can't work
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with self.assertRaises(OSError) as error_context:
+                tmpdirname = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/stable-diffusion-broken-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
+                    use_safetensors=use_safetensors,
+                )
+            assert "Error no file name" in str(error_context.exception)
 
     def test_local_save_load_index(self):
         prompt = "hello"
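
For reviewers, here is a short, hypothetical usage sketch (not part of the patch) of how the reworked check behaves. It assumes `is_safetensors_compatible` is imported from the `pipeline_loading_utils` module patched above; the filenames are made-up examples in the same style as the tests.

```python
# Minimal sketch of the new component-level compatibility check.
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible

# Sharded and variant safetensors files all end in ".safetensors", so every
# component below has at least one safetensors match -> compatible.
filenames = [
    "unet/diffusion_pytorch_model.bin",
    "unet/diffusion_pytorch_model-00001-of-00002.safetensors",
    "unet/diffusion_pytorch_model-00002-of-00002.safetensors",
    "text_encoder/pytorch_model.bin",
    "text_encoder/model.fp16.safetensors",
]
assert is_safetensors_compatible(filenames)

# A component that only ships ".bin" weights makes the whole set incompatible ...
filenames += ["vae/diffusion_pytorch_model.bin"]
assert not is_safetensors_compatible(filenames)

# ... unless that component is excluded via passed_components.
assert is_safetensors_compatible(filenames, passed_components=["vae"])
```

This is also why the `variant` argument could be dropped: the check no longer tries to pair each `.bin` file with a matching safetensors filename, it only asks whether each component has some safetensors weights to load.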