@@ -486,19 +486,13 @@ def __init__(
486486 weights_dtype : jnp .dtype = jnp .float32 ,
487487 precision : jax .lax .Precision = None ,
488488 ):
489- if out_features is None :
490- out_features = hidden_size
491-
492- self .linear = nnx .Linear (
489+ self .text_embedder = NNXPixArtAlphaTextProjection (
493490 rngs = rngs ,
494491 in_features = in_features ,
495- out_features = out_features ,
496- use_bias = True ,
497- dtype = jnp .float32 ,
498- param_dtype = weights_dtype ,
492+ hidden_size = embedding_dim ,
493+ dtype = dtype ,
494+ weights_dtype = weights_dtype ,
499495 precision = precision ,
500- kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), ("mlp" , "embed" )),
501- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("embed" ,)),
502496 )
503497
504498 self .time_proj = NNXTimesteps (num_channels = 256 , flip_sin_to_cos = True , downscale_freq_shift = 0 )
@@ -516,7 +510,7 @@ def __init__(self, rngs: nnx.Rngs, embedding_dim: int, weights_dtype: jnp.dtype)
516510 self .emb = EmbWrapper (rngs , embedding_dim , weights_dtype )
517511
518512 def __call__ (self , caption , timestep ):
519- hidden_states = self .linear (caption )
513+ hidden_states = self .text_embedder (caption )
520514
521515 timesteps_proj = self .time_proj (timestep )
522516 timesteps_emb = self .emb .timestep_embedder (timesteps_proj )
0 commit comments