Skip to content

Commit 56f3906

Browse files
committed
prenorm debug
1 parent 6f9af83 commit 56f3906

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

before_transformer_parity_maxdiffusion.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,16 @@ def patched_replace(self, hidden_states, attention_mask):
9797

9898
# Monkey-patch: route the connector's padded-register replacement through the
# debug-instrumented patched_replace defined above (original def is outside this view).
Embeddings1DConnector._replace_padded_with_learnable_registers = patched_replace
9999

100+
# Import the 1D transformer block so its __call__ can be wrapped for debugging.
from maxdiffusion.models.ltx2.text_encoders.embeddings_connector_ltx2 import _BasicTransformerBlock1D

# Keep a reference to the unpatched __call__ so the debug wrapper can delegate to it.
orig_block_call = _BasicTransformerBlock1D.__call__
103+
def patched_block_call(self, hidden_states, attention_mask=None, rotary_emb=None):
    """Debug wrapper around _BasicTransformerBlock1D.__call__.

    Logs summary statistics (std / min / max) of the incoming
    ``hidden_states`` via ``jax.debug.print`` (trace-safe printing),
    then delegates unchanged to the original implementation.
    """
    input_std = jnp.std(hidden_states)
    input_min = jnp.min(hidden_states)
    input_max = jnp.max(hidden_states)
    jax.debug.print(
        "[MAXDIFFUSION] Block Input std: {std}, min: {min}, max: {max}",
        std=input_std,
        min=input_min,
        max=input_max,
    )
    # Delegate to the saved original __call__ so behavior is unchanged.
    return orig_block_call(
        self, hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb
    )
107+
108+
# Install the debug wrapper in place of the block's original __call__.
_BasicTransformerBlock1D.__call__ = patched_block_call
109+
100110
# Patch Transformer forward pass to intercept inputs and EXIT EARLY
101111
orig_transformer_forward_pass = pipe_module.transformer_forward_pass
102112
def patched_transformer_forward_pass(*args, **kwargs):

0 commit comments

Comments
 (0)