@@ -385,25 +385,6 @@ def __call__(self, caption):
         return hidden_states
 
 
-class NNXSimpleLinearWrapper(nnx.Module):
-
-    def __init__(self, rngs: nnx.Rngs, in_features: int, out_features: int, weights_dtype: jnp.dtype):
-        super().__init__()
-        self.linear = nnx.Linear(
-            rngs=rngs,
-            in_features=in_features,
-            out_features=out_features,
-            use_bias=True,
-            dtype=jnp.float32,
-            param_dtype=weights_dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
-            bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
-        )
-
-    def __call__(self, x):
-        return self.linear(x)
-
-
 class PixArtAlphaTextProjection(nn.Module):
     """
     Projects caption embeddings. Also handles dropout for classifier-free guidance.
@@ -505,13 +486,19 @@ def __init__(
         weights_dtype: jnp.dtype = jnp.float32,
         precision: jax.lax.Precision = None,
     ):
-        self.text_embedder = NNXPixArtAlphaTextProjection(
+        if out_features is None:
+            out_features = hidden_size
+
+        self.linear = nnx.Linear(
             rngs=rngs,
             in_features=in_features,
-            hidden_size=embedding_dim,
-            dtype=dtype,
-            weights_dtype=weights_dtype,
+            out_features=out_features,
+            use_bias=True,
+            dtype=jnp.float32,
+            param_dtype=weights_dtype,
             precision=precision,
+            kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
+            bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
         )
 
         self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -529,7 +516,7 @@ def __init__(self, rngs: nnx.Rngs, embedding_dim: int, weights_dtype: jnp.dtype)
         self.emb = EmbWrapper(rngs, embedding_dim, weights_dtype)
 
     def __call__(self, caption, timestep):
-        hidden_states = self.text_embedder(caption)
+        hidden_states = self.linear(caption)
 
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.emb.timestep_embedder(timesteps_proj)
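
As a standalone reference, the sketch below shows how the inlined nnx.Linear with partitioned initializers behaves on its own, mirroring the projection introduced in this change. The concrete feature sizes, the bfloat16 value used for weights_dtype, and the example input shape are assumptions for illustration, not values taken from the diff.

# Minimal sketch (assumed sizes/dtypes), not part of the commit itself.
import jax.numpy as jnp
from flax import nnx

rngs = nnx.Rngs(0)
linear = nnx.Linear(
    in_features=4096,           # assumed caption embedding width
    out_features=1152,          # assumed transformer hidden size
    use_bias=True,
    dtype=jnp.float32,          # compute dtype
    param_dtype=jnp.bfloat16,   # plays the role of weights_dtype in the diff
    kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
    bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
    rngs=rngs,
)

caption = jnp.ones((2, 300, 4096))   # (batch, tokens, in_features)
hidden_states = linear(caption)      # projected to (2, 300, 1152)

nnx.with_partitioning only attaches sharding metadata ("mlp"/"embed" axis names) to the created parameters; the layer runs unsharded as written and picks up the named mesh axes when used under a device mesh.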