Skip to content

Commit

Permalink
[docs] refactoring docstrings in models/embeddings_flax.py (hugging…
Browse files Browse the repository at this point in the history
…face#9592)

* [docs] refactoring docstrings in `models/embeddings_flax.py`

* Update src/diffusers/models/embeddings_flax.py

* make style

---------

Co-authored-by: Aryan <[email protected]>
  • Loading branch information
Jwaminju and a-r-r-o-w authored Oct 15, 2024
1 parent fff4be8 commit a3e8d3f
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions src/diffusers/models/embeddings_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
"""Returns the positional encoding (same as Tensor2Tensor).
Args:
timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
embedding_dim: The number of output channels.
min_timescale: The smallest time unit (should probably be 0.0).
max_timescale: The largest time unit.
timesteps (`jnp.ndarray` of shape `(N,)`):
A 1-D array of N indices, one per batch element. These may be fractional.
embedding_dim (`int`):
The number of output channels.
freq_shift (`float`, *optional*, defaults to `1`):
Shift applied to the frequency scaling of the embeddings.
min_timescale (`float`, *optional*, defaults to `1`):
The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
max_timescale (`float`, *optional*, defaults to `1.0e4`):
The largest time unit used in the sinusoidal calculation.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the order of sinusoidal components to cosine first.
scale (`float`, *optional*, defaults to `1.0`):
A scaling factor applied to the positional embeddings.
Returns:
a Tensor of timing signals [N, num_channels]
"""
Expand Down Expand Up @@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
Args:
time_embed_dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
Time step embedding dimension.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
The data type for the embedding parameters.
"""

time_embed_dim: int = 32
Expand All @@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
Args:
dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
Time step embedding dimension.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sinusoidal function from sine to cosine.
freq_shift (`float`, *optional*, defaults to `1`):
Frequency shift applied to the sinusoidal embeddings.
"""

dim: int = 32
Expand Down

0 comments on commit a3e8d3f

Please sign in to comment.