@@ -501,3 +501,38 @@ def __call__(self, timestep, guidance, pooled_projection):
501501 conditioning = time_guidance_emb + pooled_projections
502502
503503 return conditioning
504+
505+
506+ class NNXTimesteps (nnx .Module ):
507+ def __init__ (self , num_channels : int , flip_sin_to_cos : bool , downscale_freq_shift : float , scale : int = 1 ):
508+ self .num_channels = num_channels
509+ self .flip_sin_to_cos = flip_sin_to_cos
510+ self .downscale_freq_shift = downscale_freq_shift
511+ self .scale = scale
512+
513+ def __call__ (self , timesteps : jax .Array ) -> jax .Array :
514+ return get_sinusoidal_embeddings (
515+ timesteps = timesteps ,
516+ embedding_dim = self .num_channels ,
517+ freq_shift = self .downscale_freq_shift ,
518+ flip_sin_to_cos = self .flip_sin_to_cos ,
519+ scale = self .scale
520+ )
521+
522+
523+ class NNXPixArtAlphaCombinedTimestepSizeEmbeddings (nnx .Module ):
524+ def __init__ (self , rngs : nnx .Rngs , embedding_dim : int , size_emb_dim : int , dtype : jnp .dtype = jnp .float32 , weights_dtype : jnp .dtype = jnp .float32 ):
525+ self .outdim = size_emb_dim
526+ self .time_proj = NNXTimesteps (num_channels = 256 , flip_sin_to_cos = True , downscale_freq_shift = 0 )
527+ self .timestep_embedder = NNXTimestepEmbedding (
528+ rngs = rngs ,
529+ in_channels = 256 ,
530+ time_embed_dim = embedding_dim ,
531+ dtype = dtype ,
532+ weights_dtype = weights_dtype
533+ )
534+
535+ def __call__ (self , timestep : jax .Array , hidden_dtype : jnp .dtype = jnp .float32 ) -> jax .Array :
536+ timesteps_proj = self .time_proj (timestep )
537+ timesteps_emb = self .timestep_embedder (timesteps_proj .astype (hidden_dtype ))
538+ return timesteps_emb
0 commit comments