Skip to content

Commit 7387fb2

Browse files
committed
transformer fix
1 parent bee4f96 commit 7387fb2

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -795,12 +795,16 @@ def __init__(
795795

796796
# 3. Output Layer Scale/Shift Modulation parameters
797797
param_rng = rngs.params()
798+
# LTX 2.3 uses 2 scale/shift parameters (rows of the table) in the output layer, LTX2 uses 6; the count is selected via the gated-attention flags.
799+
num_scale_shift_params = 2 if gated_attn else 6
800+
num_audio_scale_shift_params = 2 if audio_gated_attn else 6
801+
798802
self.scale_shift_table = nnx.Param(
799-
jax.random.normal(param_rng, (6, inner_dim), dtype=self.weights_dtype) / jnp.sqrt(inner_dim),
803+
jax.random.normal(param_rng, (num_scale_shift_params, inner_dim), dtype=self.weights_dtype) / jnp.sqrt(inner_dim),
800804
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, "embed")),
801805
)
802806
self.audio_scale_shift_table = nnx.Param(
803-
jax.random.normal(param_rng, (6, audio_inner_dim), dtype=self.weights_dtype) / jnp.sqrt(audio_inner_dim),
807+
jax.random.normal(param_rng, (num_audio_scale_shift_params, audio_inner_dim), dtype=self.weights_dtype) / jnp.sqrt(audio_inner_dim),
804808
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, "embed")),
805809
)
806810

0 commit comments

Comments
 (0)