Skip to content

Commit ff2c164

Browse files
committed
fix
1 parent 3ec5421 commit ff2c164

1 file changed

Lines changed: 2 additions & 10 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,6 @@ def __init__(
158158
rope_type=rope_type,
159159
)
160160

161-
# Scale Shift Tables
162-
self.scale_shift_table = nnx.Param(jax.random.normal(rngs.params(), (6, self.dim), dtype=weights_dtype) / jnp.sqrt(self.dim))
163-
self.audio_scale_shift_table = nnx.Param(
164-
jax.random.normal(rngs.params(), (6, audio_dim), dtype=weights_dtype) / jnp.sqrt(audio_dim)
165-
)
166-
self.video_a2v_cross_attn_scale_shift_table = nnx.Param(jax.random.normal(rngs.params(), (5, self.dim), dtype=weights_dtype))
167-
self.audio_a2v_cross_attn_scale_shift_table = nnx.Param(jax.random.normal(rngs.params(), (5, audio_dim), dtype=weights_dtype))
168-
169161
# 2. Prompt Cross-Attention
170162
self.norm2 = nnx.RMSNorm(
171163
self.dim,
@@ -815,7 +807,7 @@ def init_block(rngs):
815807
# 6. Output layers
816808
self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
817809
self.norm_out = nnx.LayerNorm(
818-
inner_dim, epsilon=1e-6, use_scale=False, use_bias=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
810+
inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
819811
)
820812
self.proj_out = nnx.Linear(
821813
inner_dim,
@@ -828,7 +820,7 @@ def init_block(rngs):
828820
)
829821

830822
self.audio_norm_out = nnx.LayerNorm(
831-
audio_inner_dim, epsilon=1e-6, use_scale=False, use_bias=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
823+
audio_inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
832824
)
833825
self.audio_proj_out = nnx.Linear(
834826
audio_inner_dim,

0 commit comments

Comments (0)