|
21 | 21 |
|
22 | 22 | from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed |
23 | 23 | from maxdiffusion.models.attention_flax import NNXSimpleFeedForward |
24 | | -from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings, NNXPixArtAlphaTextProjection, NNXCombinedTimestepTextProjEmbeddings |
| 24 | +from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings, NNXPixArtAlphaTextProjection, NNXCombinedTimestepTextProjEmbeddings, NNXSimpleLinearWrapper |
25 | 25 | from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType |
26 | 26 | from maxdiffusion.configuration_utils import ConfigMixin, register_to_config |
27 | 27 | from maxdiffusion.common_types import BlockSizes |
28 | 28 |
|
29 | 29 |
|
| 30 | + |
30 | 31 | class LTX2AdaLayerNormSingle(nnx.Module): |
31 | 32 |
|
32 | 33 | def __init__( |
@@ -695,20 +696,16 @@ def __init__( |
695 | 696 |
|
696 | 697 | # 2. Prompt embeddings |
697 | 698 | if self.cross_attn_mod: |
698 | | - self.caption_projection = NNXCombinedTimestepTextProjEmbeddings( |
| 699 | + self.caption_projection = NNXSimpleLinearWrapper( |
699 | 700 | rngs=rngs, |
700 | 701 | in_features=self.caption_channels, |
701 | | - hidden_size=self.cross_attention_dim, |
702 | | - embedding_dim=inner_dim, |
703 | | - dtype=self.dtype, |
| 702 | + out_features=self.cross_attention_dim, |
704 | 703 | weights_dtype=self.weights_dtype, |
705 | 704 | ) |
706 | | - self.audio_caption_projection = NNXCombinedTimestepTextProjEmbeddings( |
| 705 | + self.audio_caption_projection = NNXSimpleLinearWrapper( |
707 | 706 | rngs=rngs, |
708 | 707 | in_features=self.audio_caption_channels, |
709 | | - hidden_size=self.audio_cross_attention_dim, |
710 | | - embedding_dim=audio_inner_dim, |
711 | | - dtype=self.dtype, |
| 708 | + out_features=self.audio_cross_attention_dim, |
712 | 709 | weights_dtype=self.weights_dtype, |
713 | 710 | ) |
714 | 711 | else: |
|
0 commit comments