Skip to content

Commit d9f55d2

Browse files
committed
revert
1 parent 42a97d7 commit d9f55d2

2 files changed

Lines changed: 19 additions & 28 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -385,25 +385,6 @@ def __call__(self, caption):
385385
return hidden_states
386386

387387

388-
class NNXSimpleLinearWrapper(nnx.Module):
389-
390-
def __init__(self, rngs: nnx.Rngs, in_features: int, out_features: int, weights_dtype: jnp.dtype):
391-
super().__init__()
392-
self.linear = nnx.Linear(
393-
rngs=rngs,
394-
in_features=in_features,
395-
out_features=out_features,
396-
use_bias=True,
397-
dtype=jnp.float32,
398-
param_dtype=weights_dtype,
399-
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
400-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
401-
)
402-
403-
def __call__(self, x):
404-
return self.linear(x)
405-
406-
407388
class PixArtAlphaTextProjection(nn.Module):
408389
"""
409390
Projects caption embeddings. Also handles dropout for classifier-free guidance.
@@ -505,13 +486,19 @@ def __init__(
505486
weights_dtype: jnp.dtype = jnp.float32,
506487
precision: jax.lax.Precision = None,
507488
):
508-
self.text_embedder = NNXPixArtAlphaTextProjection(
489+
if out_features is None:
490+
out_features = hidden_size
491+
492+
self.linear = nnx.Linear(
509493
rngs=rngs,
510494
in_features=in_features,
511-
hidden_size=embedding_dim,
512-
dtype=dtype,
513-
weights_dtype=weights_dtype,
495+
out_features=out_features,
496+
use_bias=True,
497+
dtype=jnp.float32,
498+
param_dtype=weights_dtype,
514499
precision=precision,
500+
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
501+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
515502
)
516503

517504
self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -529,7 +516,7 @@ def __init__(self, rngs: nnx.Rngs, embedding_dim: int, weights_dtype: jnp.dtype)
529516
self.emb = EmbWrapper(rngs, embedding_dim, weights_dtype)
530517

531518
def __call__(self, caption, timestep):
532-
hidden_states = self.text_embedder(caption)
519+
hidden_states = self.linear(caption)
533520

534521
timesteps_proj = self.time_proj(timestep)
535522
timesteps_emb = self.emb.timestep_embedder(timesteps_proj)

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -696,16 +696,20 @@ def __init__(
696696

697697
# 2. Prompt embeddings
698698
if self.cross_attn_mod:
699-
self.caption_projection = NNXSimpleLinearWrapper(
699+
self.caption_projection = NNXCombinedTimestepTextProjEmbeddings(
700700
rngs=rngs,
701701
in_features=self.caption_channels,
702-
out_features=self.cross_attention_dim,
702+
hidden_size=self.cross_attention_dim,
703+
embedding_dim=inner_dim,
704+
dtype=self.dtype,
703705
weights_dtype=self.weights_dtype,
704706
)
705-
self.audio_caption_projection = NNXSimpleLinearWrapper(
707+
self.audio_caption_projection = NNXCombinedTimestepTextProjEmbeddings(
706708
rngs=rngs,
707709
in_features=self.audio_caption_channels,
708-
out_features=self.audio_cross_attention_dim,
710+
hidden_size=self.audio_cross_attention_dim,
711+
embedding_dim=audio_inner_dim,
712+
dtype=self.dtype,
709713
weights_dtype=self.weights_dtype,
710714
)
711715
else:

0 commit comments

Comments (0)