Skip to content

Commit

Permalink
Refac training utils.py (huggingface#9815)
Browse files Browse the repository at this point in the history
* Refac training utils.py

* quality

---------

Co-authored-by: sayakpaul <[email protected]>
  • Loading branch information
RogerSinghChugh and sayakpaul authored Nov 4, 2024
1 parent 13e8fde commit a3cc641
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/diffusers/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def set_seed(seed: int):
Args:
seed (`int`): The seed to set.
Returns:
`None`
"""
random.seed(seed)
np.random.seed(seed)
Expand All @@ -58,6 +61,17 @@ def compute_snr(noise_scheduler, timesteps):
"""
Computes SNR as per
https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
for the given timesteps using the provided noise scheduler.
Args:
noise_scheduler (`NoiseScheduler`):
An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
the SNR values.
timesteps (`torch.Tensor`):
A tensor of timesteps for which the SNR is computed.
Returns:
`torch.Tensor`: A tensor containing the computed SNR values for each timestep.
"""
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
Expand Down

0 comments on commit a3cc641

Please sign in to comment.