@@ -473,6 +473,71 @@ def __call__(self, timestep, pooled_projection):
473473 conditioning = timestep_emb + pooled_projections
474474 return conditioning
475475
476+ class NNXCombinedTimestepTextProjEmbeddings (nnx .Module ):
477+ def __init__ (
478+ self ,
479+ rngs : nnx .Rngs ,
480+ in_features : int ,
481+ hidden_size : int ,
482+ embedding_dim : int ,
483+ out_features : int = None ,
484+ act_fn : str = "gelu_tanh" ,
485+ dtype : jnp .dtype = jnp .float32 ,
486+ weights_dtype : jnp .dtype = jnp .float32 ,
487+ precision : jax .lax .Precision = None ,
488+ ):
489+ if out_features is None :
490+ out_features = hidden_size
491+
492+ self .linear_1 = nnx .Linear (
493+ rngs = rngs ,
494+ in_features = in_features ,
495+ out_features = hidden_size ,
496+ use_bias = True ,
497+ dtype = jnp .float32 ,
498+ param_dtype = weights_dtype ,
499+ precision = precision ,
500+ kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), ("embed" , "mlp" )),
501+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("mlp" ,)),
502+ )
503+ self .act_1 = get_activation (act_fn )
504+
505+ self .linear_2 = nnx .Linear (
506+ rngs = rngs ,
507+ in_features = hidden_size ,
508+ out_features = out_features ,
509+ use_bias = True ,
510+ dtype = jnp .float32 ,
511+ param_dtype = weights_dtype ,
512+ precision = precision ,
513+ kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), ("mlp" , "embed" )),
514+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("embed" ,)),
515+ )
516+
517+ self .time_proj = NNXTimesteps (num_channels = 256 , flip_sin_to_cos = True , downscale_freq_shift = 0 )
518+
519+ class EmbWrapper (nnx .Module ):
520+ def __init__ (self , rngs : nnx .Rngs , embedding_dim : int , weights_dtype : jnp .dtype ):
521+ self .timestep_embedder = NNXTimestepEmbedding (
522+ rngs = rngs ,
523+ in_channels = 256 ,
524+ time_embed_dim = embedding_dim ,
525+ dtype = jnp .float32 ,
526+ weights_dtype = weights_dtype ,
527+ )
528+
529+ self .emb = EmbWrapper (rngs , embedding_dim , weights_dtype )
530+
531+ def __call__ (self , caption , timestep ):
532+ hidden_states = self .linear_1 (caption )
533+ hidden_states = self .act_1 (hidden_states )
534+ hidden_states = self .linear_2 (hidden_states )
535+
536+ timesteps_proj = self .time_proj (timestep )
537+ timesteps_emb = self .emb .timestep_embedder (timesteps_proj )
538+
539+ return hidden_states + timesteps_emb
540+
476541
477542class CombinedTimestepGuidanceTextProjEmbeddings (nn .Module ):
478543 embedding_dim : int
0 commit comments