@@ -489,22 +489,9 @@ def __init__(
489489 if out_features is None :
490490 out_features = hidden_size
491491
492- self .linear_1 = nnx .Linear (
492+ self .linear = nnx .Linear (
493493 rngs = rngs ,
494494 in_features = in_features ,
495- out_features = hidden_size ,
496- use_bias = True ,
497- dtype = jnp .float32 ,
498- param_dtype = weights_dtype ,
499- precision = precision ,
500- kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), ("embed" , "mlp" )),
501- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("mlp" ,)),
502- )
503- self .act_1 = get_activation (act_fn )
504-
505- self .linear_2 = nnx .Linear (
506- rngs = rngs ,
507- in_features = hidden_size ,
508495 out_features = out_features ,
509496 use_bias = True ,
510497 dtype = jnp .float32 ,
@@ -529,9 +516,7 @@ def __init__(self, rngs: nnx.Rngs, embedding_dim: int, weights_dtype: jnp.dtype)
529516 self .emb = EmbWrapper (rngs , embedding_dim , weights_dtype )
530517
531518 def __call__ (self , caption , timestep ):
532- hidden_states = self .linear_1 (caption )
533- hidden_states = self .act_1 (hidden_states )
534- hidden_states = self .linear_2 (hidden_states )
519+ hidden_states = self .linear (caption )
535520
536521 timesteps_proj = self .time_proj (timestep )
537522 timesteps_emb = self .emb .timestep_embedder (timesteps_proj )
0 commit comments