We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 23f05d3 commit 38f7162Copy full SHA for 38f7162
1 file changed
src/maxdiffusion/models/ltx2/transformer_ltx2.py
@@ -350,7 +350,8 @@ def __call__(
350
351
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
352
hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
353
- audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names)
+ axis_names_audio = nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed"))
354
+ audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names_audio)
355
356
if encoder_hidden_states is not None:
357
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
0 commit comments