Skip to content

Commit a36a679

Browse files
committed
norm1, attn1, norm2, ff debug
1 parent 56f3906 commit a36a679

1 file changed

Lines changed: 24 additions & 1 deletion

File tree

before_transformer_parity_maxdiffusion.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,30 @@ def patched_replace(self, hidden_states, attention_mask):
103103
def patched_block_call(self, hidden_states, attention_mask=None, rotary_emb=None):
104104
jax.debug.print("[MAXDIFFUSION] Block Input std: {std}, min: {min}, max: {max}",
105105
std=jnp.std(hidden_states), min=jnp.min(hidden_states), max=jnp.max(hidden_states))
106-
return orig_block_call(self, hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
106+
107+
# 1. Norm -> Attention
108+
normed1 = self.norm1(hidden_states)
109+
jax.debug.print(" [MAXDIFFUSION] norm1 std: {std}, min: {min}, max: {max}",
110+
std=jnp.std(normed1), min=jnp.min(normed1), max=jnp.max(normed1))
111+
112+
attn_output = self.attn1(normed1, attention_mask=attention_mask, rotary_emb=rotary_emb)
113+
jax.debug.print(" [MAXDIFFUSION] attn1 std: {std}, min: {min}, max: {max}",
114+
std=jnp.std(attn_output), min=jnp.min(attn_output), max=jnp.max(attn_output))
115+
116+
hidden_states = hidden_states + attn_output
117+
118+
# 2. Norm -> FeedForward
119+
normed2 = self.norm2(hidden_states)
120+
jax.debug.print(" [MAXDIFFUSION] norm2 std: {std}, min: {min}, max: {max}",
121+
std=jnp.std(normed2), min=jnp.min(normed2), max=jnp.max(normed2))
122+
123+
ff_output = self.ff(normed2)
124+
jax.debug.print(" [MAXDIFFUSION] ff std: {std}, min: {min}, max: {max}",
125+
std=jnp.std(ff_output), min=jnp.min(ff_output), max=jnp.max(ff_output))
126+
127+
hidden_states = hidden_states + ff_output
128+
129+
return hidden_states
107130

108131
_BasicTransformerBlock1D.__call__ = patched_block_call
109132

0 commit comments

Comments
 (0)