@@ -54,15 +54,22 @@ def print_stat(name, t):
5454 else :
5555 _print_stat_impl (name , t )
5656
57- # Patch Connectors
58- orig_connector_call = LTX2AudioVideoGemmaTextEncoder .__call__
59- def patched_connector_call (self , hidden_states , attention_mask ):
60- out = orig_connector_call (self , hidden_states , attention_mask )
61- print ("\n === CONNECTORS OUTPUTS ===" )
62- print_stat ("connectors_video" , out [0 ])
63- print_stat ("connectors_audio" , out [1 ])
57+ from maxdiffusion .models .ltx2 .text_encoders .feature_extractor_ltx2 import LTX2GemmaFeatureExtractor , _norm_and_concat_padded_batch
58+
59+ # Patch Feature Extractor
60+ orig_fe_call = LTX2GemmaFeatureExtractor .__call__
61+ def patched_fe_call (self , hidden_states , attention_mask ):
62+ if isinstance (hidden_states , (tuple , list )):
63+ x = jnp .stack (hidden_states , axis = - 1 )
64+ else :
65+ x = hidden_states
66+ x_norm = _norm_and_concat_padded_batch (x , attention_mask )
67+ print ("\n === FEATURE EXTRACTOR / TEXT PROJ OUTPUTS ===" )
68+ print_stat ("packed_text_embeds" , x_norm )
69+ out = self .linear (x_norm )
70+ print_stat ("text_proj_out" , out )
6471 return out
65- LTX2AudioVideoGemmaTextEncoder .__call__ = patched_connector_call
72+ LTX2GemmaFeatureExtractor .__call__ = patched_fe_call
6673
6774# Patch Transformer forward pass to intercept inputs and EXIT EARLY
6875orig_transformer_forward_pass = pipe_module .transformer_forward_pass
0 commit comments