@@ -35,29 +35,32 @@ def get_sinusoidal_embeddings(
3535 """Returns the positional encoding (same as Tensor2Tensor).
3636
3737 Args:
38- timesteps: a 1-D Tensor of N indices, one per batch element .
38+ timesteps: a 1-D or 2-D Tensor of indices.
3939 These may be fractional.
4040 embedding_dim: The number of output channels.
4141 min_timescale: The smallest time unit (should probably be 0.0).
4242 max_timescale: The largest time unit.
4343 Returns:
44- a Tensor of timing signals [N, num_channels]
44+ a Tensor of timing signals [B, num_channels] or [B, N, num_channels]
4545 """
46- assert timesteps .ndim == 1 , "Timesteps should be a 1d-array"
46+ assert timesteps .ndim <= 2 , "Timesteps should be a 1d or 2d -array"
4747 assert embedding_dim % 2 == 0 , f"Embedding dimension { embedding_dim } should be even"
4848 num_timescales = float (embedding_dim // 2 )
4949 log_timescale_increment = math .log (max_timescale / min_timescale ) / (num_timescales - freq_shift )
5050 inv_timescales = min_timescale * jnp .exp (jnp .arange (num_timescales , dtype = jnp .float32 ) * - log_timescale_increment )
51- emb = jnp .expand_dims (timesteps , 1 ) * jnp . expand_dims ( inv_timescales , 0 )
51+ emb = jnp .expand_dims (timesteps , - 1 ) * inv_timescales
5252
5353 # scale embeddings
5454 scaled_time = scale * emb
5555
5656 if flip_sin_to_cos :
57- signal = jnp .concatenate ([jnp .cos (scaled_time ), jnp .sin (scaled_time )], axis = 1 )
57+ signal = jnp .concatenate (
58+ [jnp .cos (scaled_time ), jnp .sin (scaled_time )], axis = - 1
59+ )
5860 else :
59- signal = jnp .concatenate ([jnp .sin (scaled_time ), jnp .cos (scaled_time )], axis = 1 )
60- signal = jnp .reshape (signal , [jnp .shape (timesteps )[0 ], embedding_dim ])
61+ signal = jnp .concatenate (
62+ [jnp .sin (scaled_time ), jnp .cos (scaled_time )], axis = - 1
63+ )
6164 return signal
6265
6366
@@ -84,7 +87,7 @@ def __init__(
8487 sample_proj_bias = True ,
8588 dtype : jnp .dtype = jnp .float32 ,
8689 weights_dtype : jnp .dtype = jnp .float32 ,
87- precision : jax .lax .Precision = None ,
90+ precision : jax .lax .Precision | None = None ,
8891 ):
8992 self .linear_1 = nnx .Linear (
9093 rngs = rngs ,
@@ -221,7 +224,7 @@ def __call__(self, timesteps):
221224
222225def get_1d_rotary_pos_embed (
223226 dim : int ,
224- pos : Union [jnp .array , int ],
227+ pos : Union [jnp .ndarray , int ],
225228 theta : float = 10000.0 ,
226229 linear_factor = 1.0 ,
227230 ntk_factor = 1.0 ,
@@ -332,11 +335,11 @@ def __init__(
332335 rngs : nnx .Rngs ,
333336 in_features : int ,
334337 hidden_size : int ,
335- out_features : int = None ,
338+ out_features : int | None = None ,
336339 act_fn : str = "gelu_tanh" ,
337340 dtype : jnp .dtype = jnp .float32 ,
338341 weights_dtype : jnp .dtype = jnp .float32 ,
339- precision : jax .lax .Precision = None ,
342+ precision : jax .lax .Precision | None = None ,
340343 ):
341344 if out_features is None :
342345 out_features = hidden_size
@@ -392,11 +395,11 @@ class PixArtAlphaTextProjection(nn.Module):
392395 """
393396
394397 hidden_size : int
395- out_features : int = None
398+ out_features : int | None = None
396399 act_fn : str = "gelu_tanh"
397400 dtype : jnp .dtype = jnp .float32
398401 weights_dtype : jnp .dtype = jnp .float32
399- precision : jax .lax .Precision = None
402+ precision : jax .lax .Precision | None = None
400403
401404 @nn .compact
402405 def __call__ (self , caption ):
@@ -455,7 +458,7 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
455458 pooled_projection_dim : int
456459 dtype : jnp .dtype = jnp .float32
457460 weights_dtype : jnp .dtype = jnp .float32
458- precision : jax .lax .Precision = None
461+ precision : jax .lax .Precision | None = None
459462
460463 @nn .compact
461464 def __call__ (self , timestep , pooled_projection ):
@@ -479,7 +482,7 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
479482 pooled_projection_dim : int
480483 dtype : jnp .dtype = jnp .float32
481484 weights_dtype : jnp .dtype = jnp .float32
482- precision : jax .lax .Precision = None
485+ precision : jax .lax .Precision | None = None
483486
484487 @nn .compact
485488 def __call__ (self , timestep , guidance , pooled_projection ):
0 commit comments