Skip to content

Commit 8b9d41a

Browse files
committed
NNXCombinedTimestepTextProjEmbeddings addition
1 parent 4daa38d commit 8b9d41a

2 files changed

Lines changed: 98 additions & 15 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,71 @@ def __call__(self, timestep, pooled_projection):
473473
conditioning = timestep_emb + pooled_projections
474474
return conditioning
475475

476+
class NNXCombinedTimestepTextProjEmbeddings(nnx.Module):
  """Combined text-projection + timestep embedding (NNX port).

  Projects a caption embedding through a 2-layer MLP and adds a sinusoidal
  timestep embedding of the same width, returning a single conditioning
  tensor: ``proj(caption) + timestep_embedder(time_proj(timestep))``.

  Args:
    rngs: NNX random-number generators used to initialize all sub-layers.
    in_features: width of the incoming caption embedding.
    hidden_size: width of the MLP hidden layer (and default output width).
    embedding_dim: output width of the timestep embedder; must match
      ``out_features`` for the final addition in ``__call__`` to broadcast.
    out_features: output width of the caption MLP; defaults to
      ``hidden_size`` when ``None``. NOTE(review): annotation should be
      ``Optional[int]`` — left as-is to avoid changing visible imports.
    act_fn: name of the activation between the two linear layers,
      resolved via ``get_activation``.
    dtype: NOTE(review): accepted but currently unused — the linear layers
      below hardcode ``dtype=jnp.float32`` (presumably to keep conditioning
      math in fp32); confirm whether this parameter should be plumbed through.
    weights_dtype: parameter (storage) dtype for all sub-layers.
    precision: matmul precision passed to the linear layers.
      NOTE(review): annotation should be ``Optional[jax.lax.Precision]``.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      in_features: int,
      hidden_size: int,
      embedding_dim: int,
      out_features: int = None,
      act_fn: str = "gelu_tanh",
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
      precision: jax.lax.Precision = None,
  ):
    if out_features is None:
      out_features = hidden_size

    # Caption projection MLP: in_features -> hidden_size -> out_features.
    # Compute dtype is pinned to fp32 here (the `dtype` arg is not used).
    self.linear_1 = nnx.Linear(
        rngs=rngs,
        in_features=in_features,
        out_features=hidden_size,
        use_bias=True,
        dtype=jnp.float32,
        param_dtype=weights_dtype,
        precision=precision,
        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
    )
    self.act_1 = get_activation(act_fn)

    self.linear_2 = nnx.Linear(
        rngs=rngs,
        in_features=hidden_size,
        out_features=out_features,
        use_bias=True,
        dtype=jnp.float32,
        param_dtype=weights_dtype,
        precision=precision,
        # Partitioning axes are mirrored relative to linear_1 (mlp -> embed).
        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
    )

    # Sinusoidal projection of the scalar timestep into 256 channels.
    self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)

    # Local wrapper class so the timestep embedder lives under the module
    # path ``emb.timestep_embedder`` — presumably to match an external
    # checkpoint's parameter naming; TODO confirm against the loader.
    class EmbWrapper(nnx.Module):
      def __init__(self, rngs: nnx.Rngs, embedding_dim: int, weights_dtype: jnp.dtype):
        self.timestep_embedder = NNXTimestepEmbedding(
            rngs=rngs,
            in_channels=256,
            time_embed_dim=embedding_dim,
            dtype=jnp.float32,
            weights_dtype=weights_dtype,
        )

    self.emb = EmbWrapper(rngs, embedding_dim, weights_dtype)

  def __call__(self, caption, timestep):
    """Return ``MLP(caption) + timestep_embedding`` (shapes must broadcast)."""
    hidden_states = self.linear_1(caption)
    hidden_states = self.act_1(hidden_states)
    hidden_states = self.linear_2(hidden_states)

    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.emb.timestep_embedder(timesteps_proj)

    return hidden_states + timesteps_emb
540+
476541

477542
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
478543
embedding_dim: int

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
2323
from maxdiffusion.models.attention_flax import NNXSimpleFeedForward
24-
from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings, NNXPixArtAlphaTextProjection
24+
from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings, NNXPixArtAlphaTextProjection, NNXCombinedTimestepTextProjEmbeddings
2525
from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType
2626
from maxdiffusion.configuration_utils import ConfigMixin, register_to_config
2727
from maxdiffusion.common_types import BlockSizes
@@ -692,20 +692,38 @@ def __init__(
692692
)
693693

694694
# 2. Prompt embeddings
695-
self.caption_projection = NNXPixArtAlphaTextProjection(
696-
rngs=rngs,
697-
in_features=self.caption_channels,
698-
hidden_size=inner_dim,
699-
dtype=self.dtype,
700-
weights_dtype=self.weights_dtype,
701-
)
702-
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
703-
rngs=rngs,
704-
in_features=self.caption_channels,
705-
hidden_size=audio_inner_dim,
706-
dtype=self.dtype,
707-
weights_dtype=self.weights_dtype,
708-
)
695+
if self.cross_attn_mod:
696+
self.caption_projection = NNXCombinedTimestepTextProjEmbeddings(
697+
rngs=rngs,
698+
in_features=self.caption_channels,
699+
hidden_size=inner_dim,
700+
embedding_dim=inner_dim,
701+
dtype=self.dtype,
702+
weights_dtype=self.weights_dtype,
703+
)
704+
self.audio_caption_projection = NNXCombinedTimestepTextProjEmbeddings(
705+
rngs=rngs,
706+
in_features=self.caption_channels,
707+
hidden_size=audio_inner_dim,
708+
embedding_dim=audio_inner_dim,
709+
dtype=self.dtype,
710+
weights_dtype=self.weights_dtype,
711+
)
712+
else:
713+
self.caption_projection = NNXPixArtAlphaTextProjection(
714+
rngs=rngs,
715+
in_features=self.caption_channels,
716+
hidden_size=inner_dim,
717+
dtype=self.dtype,
718+
weights_dtype=self.weights_dtype,
719+
)
720+
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
721+
rngs=rngs,
722+
in_features=self.caption_channels,
723+
hidden_size=audio_inner_dim,
724+
dtype=self.dtype,
725+
weights_dtype=self.weights_dtype,
726+
)
709727
# 3. Timestep Modulation Params and Embedding
710728
num_mod_params = 9 if self.cross_attn_mod else 6
711729
self.time_embed = LTX2AdaLayerNormSingle(

0 commit comments

Comments
 (0)