Skip to content

Commit

Permalink
fix load sharded checkpoint from a subfolder (local path) (huggingface#8913)
Browse files Browse the repository at this point in the history

fix

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
yiyixuxu and sayakpaul authored Aug 1, 2024
1 parent c646fbc commit 95a7832
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 24 deletions.
50 changes: 26 additions & 24 deletions src/diffusers/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def _get_checkpoint_shard_files(
_check_if_shards_exist_locally(
pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
)
return pretrained_model_name_or_path, sharded_metadata
return shards_path, sharded_metadata

# At this stage pretrained_model_name_or_path is a model identifier on the Hub
allow_patterns = original_shard_filenames
Expand All @@ -467,35 +467,37 @@ def _get_checkpoint_shard_files(
"required according to the checkpoint index."
)

try:
# Load from URL
cached_folder = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
if subfolder is not None:
cached_folder = os.path.join(cached_folder, subfolder)
try:
# Load from URL
cached_folder = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
if subfolder is not None:
cached_folder = os.path.join(cached_folder, subfolder)

# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
except HTTPError as e:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
) from e
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
except HTTPError as e:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
) from e

# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
if local_files_only:
elif local_files_only:
_check_if_shards_exist_locally(
local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
)
if subfolder is not None:
cached_folder = os.path.join(cached_folder, subfolder)

return cached_folder, sharded_metadata

Expand Down
34 changes: 34 additions & 0 deletions tests/models/unets/test_models_unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,17 @@ def test_load_sharded_checkpoint_from_hub_local(self):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)

@require_torch_gpu
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
    """Load a sharded checkpoint from a locally cached repo whose weights live in a subfolder."""
    _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    # Materialize the repo on disk first so loading happens purely from the local cache.
    local_repo = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
    model = self.model_class.from_pretrained(local_repo, subfolder="unet", local_files_only=True)
    model = model.to(torch_device)
    output = model(**inputs_dict)

    assert model
    assert output.sample.shape == (4, 4, 16, 16)

@require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
Expand All @@ -1077,6 +1088,17 @@ def test_load_sharded_checkpoint_device_map_from_hub(self):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)

@require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self):
    """Load a sharded checkpoint straight from the Hub (subfolder layout) with device_map='auto'."""
    _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    repo_id = "hf-internal-testing/unet2d-sharded-dummy-subfolder"
    # device_map="auto" lets accelerate place the shards; no explicit .to(device) needed.
    model = self.model_class.from_pretrained(repo_id, subfolder="unet", device_map="auto")
    output = model(**inputs_dict)

    assert model
    assert output.sample.shape == (4, 4, 16, 16)

@require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
Expand All @@ -1087,6 +1109,18 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)

@require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
    """Load a sharded, subfolder-layout checkpoint from the local cache with device_map='auto'."""
    _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    # Download up front so from_pretrained can run with local_files_only=True.
    local_repo = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
    model = self.model_class.from_pretrained(
        local_repo, local_files_only=True, subfolder="unet", device_map="auto"
    )
    output = model(**inputs_dict)

    assert model
    assert output.sample.shape == (4, 4, 16, 16)

@require_peft_backend
def test_lora(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
Expand Down

0 comments on commit 95a7832

Please sign in to comment.