We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 94cd520 commit 449d78cCopy full SHA for 449d78c
1 file changed
src/maxdiffusion/models/ltx2/transformer_ltx2.py
@@ -353,6 +353,9 @@ def __call__(
353
axis_names_audio = nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed"))
354
audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names_audio)
355
356
+ jax.debug.print("[LTX2] Video hidden states shape: {shape}", shape=hidden_states.shape)
357
+ jax.debug.print("[LTX2] Audio hidden states shape: {shape}", shape=audio_hidden_states.shape)
358
+
359
if encoder_hidden_states is not None:
360
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
361
if audio_encoder_hidden_states is not None:
0 commit comments