Skip to content

Commit 12ec691

Browse files
committed
fix
1 parent cc5f2eb commit 12ec691

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

src/maxdiffusion/models/ltx_2/transformer_ltx2.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,19 +208,21 @@ def __init__(
208208
weights_dtype=weights_dtype
209209
)
210210

211-
scale_rng, init_rng = nnx.split_rngs(rngs, "params", "initialization")
211+
212+
key = rngs.params()
213+
k1, k2, k3, k4 = jax.random.split(key, 4)
212214

213215
self.scale_shift_table = nnx.Param(
214-
jax.random.normal(init_rng(), (6, self.dim), dtype=weights_dtype) / jnp.sqrt(self.dim)
216+
jax.random.normal(k1, (6, self.dim), dtype=weights_dtype) / jnp.sqrt(self.dim)
215217
)
216218
self.audio_scale_shift_table = nnx.Param(
217-
jax.random.normal(init_rng(), (6, audio_dim), dtype=weights_dtype) / jnp.sqrt(audio_dim)
219+
jax.random.normal(k2, (6, audio_dim), dtype=weights_dtype) / jnp.sqrt(audio_dim)
218220
)
219221
self.video_a2v_cross_attn_scale_shift_table = nnx.Param(
220-
jax.random.normal(init_rng(), (5, self.dim), dtype=weights_dtype)
222+
jax.random.normal(k3, (5, self.dim), dtype=weights_dtype)
221223
)
222224
self.audio_a2v_cross_attn_scale_shift_table = nnx.Param(
223-
jax.random.normal(init_rng(), (5, audio_dim), dtype=weights_dtype)
225+
jax.random.normal(k4, (5, audio_dim), dtype=weights_dtype)
224226
)
225227

226228
def __call__(

0 commit comments

Comments
 (0)