Skip to content

Commit 42a97d7

Browse files
committed
NNX simple feed forward wrapper
1 parent 03d162b commit 42a97d7

2 files changed

Lines changed: 25 additions & 9 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,25 @@ def __call__(self, caption):
385385
return hidden_states
386386

387387

388+
class NNXSimpleLinearWrapper(nnx.Module):
  """Thin NNX module wrapping a single ``nnx.Linear`` projection.

  Used as a drop-in caption-projection module (see ``transformer_ltx2.py``,
  where it replaces ``NNXCombinedTimestepTextProjEmbeddings``): it applies one
  dense layer ``x @ kernel + bias`` with sharding annotations on the
  parameters.

  Args:
    rngs: NNX RNG container used to initialize the linear parameters.
    in_features: Size of the last axis of the input.
    out_features: Size of the last axis of the output.
    weights_dtype: dtype the parameters are stored in (``param_dtype``).
    dtype: dtype the matmul is computed in. Defaults to ``jnp.float32``,
      which was previously hard-coded; exposing it keeps existing callers
      unchanged while allowing lower-precision compute (e.g. ``bfloat16``).
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      in_features: int,
      out_features: int,
      weights_dtype: jnp.dtype,
      dtype: jnp.dtype = jnp.float32,
  ):
    super().__init__()
    self.linear = nnx.Linear(
        rngs=rngs,
        in_features=in_features,
        out_features=out_features,
        use_bias=True,
        dtype=dtype,
        param_dtype=weights_dtype,
        # NOTE(review): kernel is sharded over ("mlp", "embed") and bias over
        # ("embed",) — presumably matching the mesh axis names used elsewhere
        # in maxdiffusion; confirm against the model's mesh definition.
        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
    )

  def __call__(self, x):
    """Apply the wrapped linear projection to ``x`` (last axis = in_features)."""
    return self.linear(x)
405+
406+
388407
class PixArtAlphaTextProjection(nn.Module):
389408
"""
390409
Projects caption embeddings. Also handles dropout for classifier-free guidance.

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@
2121

2222
from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
2323
from maxdiffusion.models.attention_flax import NNXSimpleFeedForward
24-
from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings, NNXPixArtAlphaTextProjection, NNXCombinedTimestepTextProjEmbeddings
24+
from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings, NNXPixArtAlphaTextProjection, NNXCombinedTimestepTextProjEmbeddings, NNXSimpleLinearWrapper
2525
from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType
2626
from maxdiffusion.configuration_utils import ConfigMixin, register_to_config
2727
from maxdiffusion.common_types import BlockSizes
2828

2929

30+
3031
class LTX2AdaLayerNormSingle(nnx.Module):
3132

3233
def __init__(
@@ -695,20 +696,16 @@ def __init__(
695696

696697
# 2. Prompt embeddings
697698
if self.cross_attn_mod:
698-
self.caption_projection = NNXCombinedTimestepTextProjEmbeddings(
699+
self.caption_projection = NNXSimpleLinearWrapper(
699700
rngs=rngs,
700701
in_features=self.caption_channels,
701-
hidden_size=self.cross_attention_dim,
702-
embedding_dim=inner_dim,
703-
dtype=self.dtype,
702+
out_features=self.cross_attention_dim,
704703
weights_dtype=self.weights_dtype,
705704
)
706-
self.audio_caption_projection = NNXCombinedTimestepTextProjEmbeddings(
705+
self.audio_caption_projection = NNXSimpleLinearWrapper(
707706
rngs=rngs,
708707
in_features=self.audio_caption_channels,
709-
hidden_size=self.audio_cross_attention_dim,
710-
embedding_dim=audio_inner_dim,
711-
dtype=self.dtype,
708+
out_features=self.audio_cross_attention_dim,
712709
weights_dtype=self.weights_dtype,
713710
)
714711
else:

0 commit comments

Comments
 (0)