Skip to content

Commit 38f7162

Browse files
committed
replicating audio across cp
1 parent 23f05d3 commit 38f7162

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@ def __call__(
350350

351351
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
352352
hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
353-
audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names)
353+
axis_names_audio = nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed"))
354+
audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names_audio)
354355

355356
if encoder_hidden_states is not None:
356357
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)

0 commit comments

Comments
 (0)