Skip to content

Commit 117f1ac

Browse files
committed
debug
1 parent 943c1de commit 117f1ac

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,8 @@ def _get_gemma_prompt_embeds(
769769

770770
text_encoder_hidden_states = text_encoder_outputs.hidden_states
771771
del text_encoder_outputs # Free memory
772+
max_logging.log(f"[LTX2 XPROF] Text Encoder (Gemma) produced {len(text_encoder_hidden_states)} layers.")
773+
max_logging.log(f"[LTX2 XPROF] Shape of first layer hidden states: {text_encoder_hidden_states[0].shape}")
772774

773775
prompt_embeds_list = []
774776
# Iterate instead of stacking eagerly to avoid 5.7+ GB HBM allocations outside JIT
@@ -780,6 +782,7 @@ def _get_gemma_prompt_embeds(
780782
del text_encoder_hidden_states # Free PyTorch tensor memory
781783

782784
prompt_attention_mask = jnp.array(prompt_attention_mask.cpu().to(torch.float32).numpy(), dtype=jnp.bool_)
785+
max_logging.log(f"[LTX2 XPROF] Prompt embeds produced. Number of layers/states: {len(prompt_embeds)}, shape of first: {prompt_embeds[0].shape}")
783786
else:
784787
raise ValueError("`text_encoder` is required to encode prompts.")
785788

@@ -1308,6 +1311,7 @@ def __call__(
13081311
with context_manager, axis_rules_context:
13091312
connectors_graphdef, connectors_state = nnx.split(self.connectors)
13101313

1314+
max_logging.log(f"[LTX2 XPROF] Running connectors with prompt_embeds shape: {prompt_embeds_jax.shape if not isinstance(prompt_embeds_jax, list) else len(prompt_embeds_jax)}")
13111315
video_embeds, audio_embeds, new_attention_mask = self._run_connectors(
13121316
connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_)
13131317
)

0 commit comments

Comments
 (0)