Skip to content

Commit 27225eb

Browse files
committed
transformer fix
1 parent e847423 commit 27225eb

1 file changed

Lines changed: 29 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def __init__(
101101
rope_type: str = "interleaved",
102102
video_gated_attn: bool = False,
103103
audio_gated_attn: bool = False,
104+
cross_attn_mod: bool = False,
104105
dtype: jnp.dtype = jnp.float32,
105106
weights_dtype: jnp.dtype = jnp.float32,
106107
mesh: jax.sharding.Mesh = None,
@@ -320,14 +321,17 @@ def __init__(
320321
)
321322

322323
key = rngs.params()
323-
k1, k2, k3, k4 = jax.random.split(key, 4)
324+
k1, k2, k3, k4, k5, k6 = jax.random.split(key, 6)
324325

326+
# LTX 2.3 uses 9 parameters for scale/shift in blocks, LTX2 uses 6.
327+
num_block_scale_shift_params = 9 if cross_attn_mod else 6
328+
325329
self.scale_shift_table = nnx.Param(
326-
jax.random.normal(k1, (6, self.dim), dtype=weights_dtype) / jnp.sqrt(self.dim),
330+
jax.random.normal(k1, (num_block_scale_shift_params, self.dim), dtype=weights_dtype) / jnp.sqrt(self.dim),
327331
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
328332
)
329333
self.audio_scale_shift_table = nnx.Param(
330-
jax.random.normal(k2, (6, audio_dim), dtype=weights_dtype) / jnp.sqrt(audio_dim),
334+
jax.random.normal(k2, (num_block_scale_shift_params, audio_dim), dtype=weights_dtype) / jnp.sqrt(audio_dim),
331335
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
332336
)
333337
self.video_a2v_cross_attn_scale_shift_table = nnx.Param(
@@ -339,6 +343,16 @@ def __init__(
339343
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
340344
)
341345

346+
if cross_attn_mod:
347+
self.prompt_scale_shift_table = nnx.Param(
348+
jax.random.normal(k5, (2, self.dim), dtype=weights_dtype) / jnp.sqrt(self.dim),
349+
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
350+
)
351+
self.audio_prompt_scale_shift_table = nnx.Param(
352+
jax.random.normal(k6, (2, audio_dim), dtype=weights_dtype) / jnp.sqrt(audio_dim),
353+
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
354+
)
355+
342356
def __call__(
343357
self,
344358
hidden_states: jax.Array, # Video
@@ -610,6 +624,9 @@ def __init__(
610624
qk_norm: str = "rms_norm_across_heads",
611625
flash_block_sizes: BlockSizes = None,
612626
flash_min_seq_length: int = 4096,
627+
video_gated_attn: bool = False,
628+
audio_gated_attn: bool = False,
629+
cross_attn_mod: bool = False,
613630
**kwargs,
614631
):
615632
self.in_channels = in_channels
@@ -659,6 +676,9 @@ def __init__(
659676
self.a2v_attention_kernel = a2v_attention_kernel
660677
self.v2a_attention_kernel = v2a_attention_kernel
661678
self.flash_min_seq_length = flash_min_seq_length
679+
self.video_gated_attn = video_gated_attn
680+
self.audio_gated_attn = audio_gated_attn
681+
self.cross_attn_mod = cross_attn_mod
662682

663683
_out_channels = self.out_channels or self.in_channels
664684
_audio_out_channels = self.audio_out_channels or self.audio_in_channels
@@ -881,6 +901,9 @@ def init_block(rngs):
881901
v2a_attention_kernel=self.v2a_attention_kernel,
882902
flash_block_sizes=flash_block_sizes,
883903
flash_min_seq_length=self.flash_min_seq_length,
904+
video_gated_attn=self.video_gated_attn,
905+
audio_gated_attn=self.audio_gated_attn,
906+
cross_attn_mod=self.cross_attn_mod,
884907
)
885908

886909
if self.scan_layers:
@@ -916,6 +939,9 @@ def init_block(rngs):
916939
v2a_attention_kernel=self.v2a_attention_kernel,
917940
flash_block_sizes=flash_block_sizes,
918941
flash_min_seq_length=self.flash_min_seq_length,
942+
video_gated_attn=self.video_gated_attn,
943+
audio_gated_attn=self.audio_gated_attn,
944+
cross_attn_mod=self.cross_attn_mod,
919945
)
920946
blocks.append(block)
921947
self.transformer_blocks = nnx.List(blocks)

0 commit comments

Comments (0)