Skip to content

Commit fde6d7f

Browse files
committed
before_transformer parity file
1 parent 7ebe7fb commit fde6d7f

1 file changed

Lines changed: 15 additions & 8 deletions

File tree

before_transformer_parity_maxdiffusion.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,22 @@ def print_stat(name, t):
5454
else:
5555
_print_stat_impl(name, t)
5656

57-
# Patch Connectors
58-
orig_connector_call = LTX2AudioVideoGemmaTextEncoder.__call__
59-
def patched_connector_call(self, hidden_states, attention_mask):
60-
out = orig_connector_call(self, hidden_states, attention_mask)
61-
print("\n=== CONNECTORS OUTPUTS ===")
62-
print_stat("connectors_video", out[0])
63-
print_stat("connectors_audio", out[1])
57+
from maxdiffusion.models.ltx2.text_encoders.feature_extractor_ltx2 import LTX2GemmaFeatureExtractor, _norm_and_concat_padded_batch
58+
59+
# Patch Feature Extractor
60+
orig_fe_call = LTX2GemmaFeatureExtractor.__call__
61+
def patched_fe_call(self, hidden_states, attention_mask):
62+
if isinstance(hidden_states, (tuple, list)):
63+
x = jnp.stack(hidden_states, axis=-1)
64+
else:
65+
x = hidden_states
66+
x_norm = _norm_and_concat_padded_batch(x, attention_mask)
67+
print("\n=== FEATURE EXTRACTOR / TEXT PROJ OUTPUTS ===")
68+
print_stat("packed_text_embeds", x_norm)
69+
out = self.linear(x_norm)
70+
print_stat("text_proj_out", out)
6471
return out
65-
LTX2AudioVideoGemmaTextEncoder.__call__ = patched_connector_call
72+
LTX2GemmaFeatureExtractor.__call__ = patched_fe_call
6673

6774
# Patch Transformer forward pass to intercept inputs and EXIT EARLY
6875
orig_transformer_forward_pass = pipe_module.transformer_forward_pass

0 commit comments

Comments
 (0)