Skip to content

Commit 407b0c6

Browse files
committed
NNXCombinedTimestepTextProjEmbeddings addition
1 parent 8b9d41a commit 407b0c6

1 file changed

Lines changed: 2 additions & 17 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -489,22 +489,9 @@ def __init__(
489489
if out_features is None:
490490
out_features = hidden_size
491491

492-
self.linear_1 = nnx.Linear(
492+
self.linear = nnx.Linear(
493493
rngs=rngs,
494494
in_features=in_features,
495-
out_features=hidden_size,
496-
use_bias=True,
497-
dtype=jnp.float32,
498-
param_dtype=weights_dtype,
499-
precision=precision,
500-
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp")),
501-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
502-
)
503-
self.act_1 = get_activation(act_fn)
504-
505-
self.linear_2 = nnx.Linear(
506-
rngs=rngs,
507-
in_features=hidden_size,
508495
out_features=out_features,
509496
use_bias=True,
510497
dtype=jnp.float32,
@@ -529,9 +516,7 @@ def __init__(self, rngs: nnx.Rngs, embedding_dim: int, weights_dtype: jnp.dtype)
529516
self.emb = EmbWrapper(rngs, embedding_dim, weights_dtype)
530517

531518
def __call__(self, caption, timestep):
532-
hidden_states = self.linear_1(caption)
533-
hidden_states = self.act_1(hidden_states)
534-
hidden_states = self.linear_2(hidden_states)
519+
hidden_states = self.linear(caption)
535520

536521
timesteps_proj = self.time_proj(timestep)
537522
timesteps_emb = self.emb.timestep_embedder(timesteps_proj)

0 commit comments

Comments
 (0)