Skip to content

Commit 79b084c

Browse files
committed
debug added for learnable registers
1 parent 61398b2 commit 79b084c

1 file changed

Lines changed: 12 additions & 0 deletions

File tree

before_transformer_parity_maxdiffusion.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,18 @@ def patched_fe_call(self, hidden_states, attention_mask):
7171
return out
7272
LTX2GemmaFeatureExtractor.__call__ = patched_fe_call
7373

74+
from maxdiffusion.models.ltx2.text_encoders.embeddings_connector_ltx2 import Embeddings1DConnector
75+
76+
orig_replace = Embeddings1DConnector._replace_padded_with_learnable_registers
77+
def patched_replace(self, hidden_states, attention_mask):
78+
regs = self.learnable_registers.value
79+
jax.debug.print("[MAXDIFFUSION] Connector Registers std: {std:.5f}, mean: {mean:.5f}, min: {min:.5f}",
80+
std=jnp.std(regs), mean=jnp.mean(regs), min=jnp.min(regs))
81+
82+
return orig_replace(self, hidden_states, attention_mask)
83+
84+
Embeddings1DConnector._replace_padded_with_learnable_registers = patched_replace
85+
7486
# Patch Transformer forward pass to intercept inputs and EXIT EARLY
7587
orig_transformer_forward_pass = pipe_module.transformer_forward_pass
7688
def patched_transformer_forward_pass(*args, **kwargs):

0 commit comments

Comments
 (0)