Skip to content

Commit 449d78c

Browse files
committed
printing audio and video hidden states shape
1 parent 94cd520 commit 449d78c

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,9 @@ def __call__(
353353
axis_names_audio = nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed"))
354354
audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names_audio)
355355

356+
jax.debug.print("[LTX2] Video hidden states shape: {shape}", shape=hidden_states.shape)
357+
jax.debug.print("[LTX2] Audio hidden states shape: {shape}", shape=audio_hidden_states.shape)
358+
356359
if encoder_hidden_states is not None:
357360
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
358361
if audio_encoder_hidden_states is not None:

0 commit comments

Comments
 (0)