@@ -769,6 +769,8 @@ def _get_gemma_prompt_embeds(
769769
770770 text_encoder_hidden_states = text_encoder_outputs.hidden_states
771771 del text_encoder_outputs # Free memory
772+ max_logging.log(f"[LTX2 XPROF] Text Encoder (Gemma) produced {len(text_encoder_hidden_states)} layers.")
773+ max_logging.log(f"[LTX2 XPROF] Shape of first layer hidden states: {text_encoder_hidden_states[0].shape}")
772774
773775 prompt_embeds_list = []
774776 # Iterate instead of stacking eagerly to avoid 5.7+ GB HBM allocations outside JIT
@@ -780,6 +782,7 @@ def _get_gemma_prompt_embeds(
780782 del text_encoder_hidden_states # Free PyTorch tensor memory
781783
782784 prompt_attention_mask = jnp.array(prompt_attention_mask.cpu().to(torch.float32).numpy(), dtype=jnp.bool_)
785+ max_logging.log(f"[LTX2 XPROF] Prompt embeds produced. Number of layers/states: {len(prompt_embeds)}, shape of first: {prompt_embeds[0].shape}")
783786 else:
784787 raise ValueError("`text_encoder` is required to encode prompts.")
785788
@@ -1308,6 +1311,7 @@ def __call__(
13081311 with context_manager , axis_rules_context :
13091312 connectors_graphdef , connectors_state = nnx .split (self .connectors )
13101313
1314+ max_logging.log(f"[LTX2 XPROF] Running connectors with prompt_embeds shape: {prompt_embeds_jax.shape if not isinstance(prompt_embeds_jax, list) else len(prompt_embeds_jax)}")
13111315 video_embeds, audio_embeds, new_attention_mask = self._run_connectors(
13121316 connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_)
13131317 )
0 commit comments