Skip to content

Commit 03d162b

Browse files
committed
embeddings_flax.py changes
1 parent 1c0d221 commit 03d162b

1 file changed

Lines changed: 5 additions & 11 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -486,19 +486,13 @@ def __init__(
486486
weights_dtype: jnp.dtype = jnp.float32,
487487
precision: jax.lax.Precision = None,
488488
):
489-
if out_features is None:
490-
out_features = hidden_size
491-
492-
self.linear = nnx.Linear(
489+
self.text_embedder = NNXPixArtAlphaTextProjection(
493490
rngs=rngs,
494491
in_features=in_features,
495-
out_features=out_features,
496-
use_bias=True,
497-
dtype=jnp.float32,
498-
param_dtype=weights_dtype,
492+
hidden_size=embedding_dim,
493+
dtype=dtype,
494+
weights_dtype=weights_dtype,
499495
precision=precision,
500-
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
501-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
502496
)
503497

504498
self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -516,7 +510,7 @@ def __init__(self, rngs: nnx.Rngs, embedding_dim: int, weights_dtype: jnp.dtype)
516510
self.emb = EmbWrapper(rngs, embedding_dim, weights_dtype)
517511

518512
def __call__(self, caption, timestep):
519-
hidden_states = self.linear(caption)
513+
hidden_states = self.text_embedder(caption)
520514

521515
timesteps_proj = self.time_proj(timestep)
522516
timesteps_emb = self.emb.timestep_embedder(timesteps_proj)

0 commit comments

Comments
 (0)