Skip to content

Commit 943c1de

Browse files
committed
debug
1 parent b5539a4 commit 943c1de

3 files changed

Lines changed: 6 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,7 @@ def __call__(
10751075
causal: bool = False,
10761076
deterministic: bool = True,
10771077
) -> jax.Array:
1078-
print(f"[LTX2 XPROF Tracing] Decoder __call__ input shape: {sample.shape}")
1078+
print(f"[LTX2 XPROF Tracing] Video Decoder __call__ input shape: {sample.shape}")
10791079
if self.timestep_scale_multiplier is not None and temb is not None:
10801080
temb = temb * self.timestep_scale_multiplier.value
10811081

@@ -1587,7 +1587,7 @@ def decode(
15871587
generator: Optional[jax.Array] = None,
15881588
causal: Optional[bool] = None,
15891589
) -> Union[FlaxDecoderOutput, Tuple[jax.Array]]:
1590-
print(f"[LTX2 XPROF Tracing] VAE decode input shape: {latents.shape}")
1590+
print(f"[LTX2 XPROF Tracing] Video VAE decode input shape: {latents.shape}")
15911591
causal = self.decoder_causal if causal is None else causal
15921592
key = generator
15931593

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,7 @@ def __call__(
935935
print(f"[LTX2 XPROF Tracing] hidden_states shape: {hidden_states.shape}")
936936
print(f"[LTX2 XPROF Tracing] audio_hidden_states shape: {audio_hidden_states.shape}")
937937
print(f"[LTX2 XPROF Tracing] encoder_hidden_states shape: {encoder_hidden_states.shape}")
938+
print(f"[LTX2 XPROF Tracing] audio_encoder_hidden_states shape: {audio_encoder_hidden_states.shape}")
938939

939940
# 1. Prepare RoPE positional embeddings
940941
with jax.named_scope("RoPE Preparation"):

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,6 +1482,7 @@ def __call__(
14821482
# Post-process video (converts to numpy/PIL)
14831483
# VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
14841484
video_np = np.array(video).transpose(0, 4, 1, 2, 3)
1485+
max_logging.log(f"[LTX2 XPROF] Produced video shape (B, C, T, H, W): {video_np.shape}")
14851486
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)
14861487

14871488
# Decode Audio
@@ -1494,6 +1495,7 @@ def __call__(
14941495

14951496
# Convert audio to numpy
14961497
audio = np.array(audio)
1498+
max_logging.log(f"[LTX2 XPROF] Produced audio shape: {audio.shape}")
14971499

14981500
return LTX2PipelineOutput(frames=video, audio=audio)
14991501

@@ -1531,6 +1533,7 @@ def transformer_forward_pass(
15311533
print(f"[LTX2 XPROF Tracing] latents shape: {latents.shape}")
15321534
print(f"[LTX2 XPROF Tracing] audio_latents shape: {audio_latents.shape}")
15331535
print(f"[LTX2 XPROF Tracing] encoder_hidden_states shape: {encoder_hidden_states.shape}")
1536+
print(f"[LTX2 XPROF Tracing] audio_encoder_hidden_states shape: {audio_encoder_hidden_states.shape}")
15341537

15351538
transformer = nnx.merge(graphdef, state)
15361539

0 commit comments

Comments (0)