diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 7186cb181aed..45fee35ef336 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -132,9 +132,58 @@ def _pinned_memory_tensors(self):
         finally:
             pinned_dict = None
 
+    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if self.record_stream and current_stream is not None:
+            tensor.data.record_stream(current_stream)
+
+    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+        for group_module in self.modules:
+            for param in group_module.parameters():
+                source = pinned_memory[param] if pinned_memory else param.data
+                self._transfer_tensor_to_device(param, source, current_stream)
+            for buffer in group_module.buffers():
+                source = pinned_memory[buffer] if pinned_memory else buffer.data
+                self._transfer_tensor_to_device(buffer, source, current_stream)
+
+        for param in self.parameters:
+            source = pinned_memory[param] if pinned_memory else param.data
+            self._transfer_tensor_to_device(param, source, current_stream)
+
+        for buffer in self.buffers:
+            source = pinned_memory[buffer] if pinned_memory else buffer.data
+            self._transfer_tensor_to_device(buffer, source, current_stream)
+
+    def _onload_from_disk(self, current_stream):
+        if self.stream is not None:
+            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+
+            for key, tensor_obj in self.key_to_tensor.items():
+                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
+
+            with self._pinned_memory_tensors() as pinned_memory:
+                for key, tensor_obj in self.key_to_tensor.items():
+                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
+
+            self.cpu_param_dict.clear()
+
+        else:
+            onload_device = (
+                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+            )
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+            for key, tensor_obj in self.key_to_tensor.items():
+                tensor_obj.data = loaded_tensors[key]
+
+    def _onload_from_memory(self, current_stream):
+        if self.stream is not None:
+            with self._pinned_memory_tensors() as pinned_memory:
+                self._process_tensors_from_modules(pinned_memory, current_stream)
+        else:
+            self._process_tensors_from_modules(None, current_stream)
+
     @torch.compiler.disable()
     def onload_(self):
-        r"""Onloads the group of modules to the onload_device."""
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
@@ -172,67 +221,30 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            if self.stream is not None:
-                with self._pinned_memory_tensors() as pinned_memory:
-                    for group_module in self.modules:
-                        for param in group_module.parameters():
-                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                param.data.record_stream(current_stream)
-                        for buffer in group_module.buffers():
-                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                buffer.data.record_stream(current_stream)
-
-                    for param in self.parameters:
-                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            param.data.record_stream(current_stream)
-
-                    for buffer in self.buffers:
-                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            buffer.data.record_stream(current_stream)
-
+            if self.offload_to_disk_path:
+                self._onload_from_disk(current_stream)
             else:
-                for group_module in self.modules:
-                    for param in group_module.parameters():
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    for buffer in group_module.buffers():
-                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    if self.record_stream:
-                        buffer.data.record_stream(current_stream)
-
-    @torch.compiler.disable()
-    def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
-        if self.offload_to_disk_path:
-            # TODO: we can potentially optimize this code path by checking if the _all_ the desired
-            # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
-            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
-            # we perform a write.
-            # Check if the file has been saved in this session or if it already exists on disk.
-            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
-                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
-                tensors_to_save = {
-                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
-                }
-                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-
-            # The group is now considered offloaded to disk for the rest of the session.
-            self._is_offloaded_to_disk = True
-
-            # We do this to free up the RAM which is still holding the up tensor data.
-            for tensor_obj in self.tensor_to_key.keys():
-                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
-            return
-
+                self._onload_from_memory(current_stream)
+
+    def _offload_to_disk(self):
+        # TODO: we can potentially optimize this code path by checking if _all_ the desired
+        # safetensor files exist on the disk and, if so, skip this step entirely, reducing IO
+        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+        # we perform a write.
+        # Check if the file has been saved in this session or if it already exists on disk.
+        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+        # The group is now considered offloaded to disk for the rest of the session.
+        self._is_offloaded_to_disk = True
+
+        # We do this to free up the RAM which is still holding on to the tensor data.
+        for tensor_obj in self.tensor_to_key.keys():
+            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+
+    def _offload_to_memory(self):
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
@@ -257,6 +269,14 @@ def offload_(self):
         for buffer in self.buffers:
             buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
+    @torch.compiler.disable()
+    def offload_(self):
+        r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            self._offload_to_disk()
+        else:
+            self._offload_to_memory()
+
 
 class GroupOffloadingHook(ModelHook):
     r"""
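For context on the stream-based path: `_transfer_tensor_to_device` pairs a `non_blocking` copy out of pinned memory with `Tensor.record_stream`, so the caching allocator does not recycle the destination tensor's memory while another stream may still read it. The following is a minimal, self-contained sketch of that pattern, not the PR's implementation: the helper name, the locally created side stream, and the bare `params` iterable are illustrative (the PR reuses `self.stream` and a prebuilt pinned-memory dict).

```python
import torch


def transfer_on_side_stream(params, device: str = "cuda") -> None:
    # Illustrative sketch of the copy pattern behind _transfer_tensor_to_device.
    side_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()
    with torch.cuda.stream(side_stream):
        for param in params:
            pinned = param.data.pin_memory()  # page-locked staging buffer
            # Async host-to-device copy; overlap is only legal from pinned memory.
            param.data = pinned.to(device, non_blocking=True)
            # Mark the new tensor as consumed on the compute stream so the
            # caching allocator does not recycle its memory too early.
            param.data.record_stream(compute_stream)
    # Make subsequent compute wait for the copies to finish.
    compute_stream.wait_stream(side_stream)
```

This is why `_transfer_tensor_to_device` only calls `record_stream` when a `current_stream` is supplied: on the synchronous path there is no cross-stream hazard to guard against.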
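The disk path boils down to a safetensors round trip plus the `tensor.data` swap used to release host memory. Below is a hedged, standalone sketch of the same round trip; the helper names are hypothetical, and the PR's `_offload_to_disk`/`_onload_from_disk` additionally handle pinned staging, the `_is_offloaded_to_disk` flag, and the `tensor_to_key`/`key_to_tensor` bookkeeping.

```python
import os

import safetensors.torch
import torch


def offload_module_to_disk(module: torch.nn.Module, path: str) -> None:
    # Hypothetical helper mirroring _offload_to_disk: write the tensors once,
    # then point each parameter at an uninitialized placeholder so the
    # original storage can be reclaimed.
    if not os.path.exists(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        tensors = {name: p.data.to("cpu") for name, p in module.named_parameters()}
        safetensors.torch.save_file(tensors, path)
    for p in module.parameters():
        p.data = torch.empty_like(p.data, device="cpu")


def onload_module_from_disk(module: torch.nn.Module, path: str, device: str) -> None:
    # Hypothetical helper mirroring the non-stream branch of _onload_from_disk:
    # load straight onto the target device and swap the data pointers back.
    loaded = safetensors.torch.load_file(path, device=device)
    for name, p in module.named_parameters():
        p.data = loaded[name]
```

Loading with `load_file(..., device=...)` avoids an intermediate CPU copy, which is presumably why the non-stream branch normalizes `self.onload_device` to a device-type string before passing it to safetensors.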
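Both `onload_` and `_offload_to_memory` resolve the device module through `torch.accelerator` when available. The `else` branch of that expression is truncated in these hunks; the sketch below assumes the usual `torch.cuda` fallback for older PyTorch builds, which is an assumption, not something this diff shows.

```python
import torch

# Version-tolerant accelerator lookup, as used in onload_()/_offload_to_memory().
# Assumption: on PyTorch builds without torch.accelerator, torch.cuda is the
# intended fallback (the else branch is not visible in this hunk).
torch_accelerator_module = (
    getattr(torch, torch.accelerator.current_accelerator().type)
    if hasattr(torch, "accelerator")
    else torch.cuda
)
torch_accelerator_module.synchronize()  # e.g. torch.cuda.synchronize()
```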