Skip to content

Commit 599ad19

Browse files
committed
attn1 debug verify
1 parent c15b36b commit 599ad19

1 file changed

Lines changed: 10 additions & 13 deletions

File tree

before_transformer_parity_maxdiffusion.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,16 @@ def patched_replace(self, hidden_states, attention_mask):
9696

9797
# Debug-only instrumentation: monkey-patch the transformer block's __call__
# so that intermediate activation statistics (after norm1 and after attn1)
# are printed before delegating to the original, unmodified implementation.
# Used to verify attn1 parity against a reference implementation.
orig_block_call = _BasicTransformerBlock1D.__call__

def patched_block_call(self, hidden_states, attention_mask=None, rotary_emb=None):
    """Log norm1/attn1 output statistics, then run the original block.

    NOTE(review): norm1 and attn1 are evaluated here a second time purely so
    their outputs can be inspected — the returned value comes from
    ``orig_block_call``, which re-runs the full block on ``hidden_states``.
    This doubles the cost of these sub-layers; acceptable for a debug patch.

    Args:
        hidden_states: input activations (shape/dtype not shown here —
            presumably a JAX array; confirm against caller).
        attention_mask: optional mask forwarded to ``self.attn1``.
        rotary_emb: optional rotary embedding forwarded to ``self.attn1``.

    Returns:
        Whatever the original ``__call__`` returns, unchanged.
    """
    normed1 = self.norm1(hidden_states)
    # jax.debug.print works under jit/tracing where a plain print would not.
    jax.debug.print(
        "DEBUG: maxdiffusion block norm1. min: {min:.5f}, max: {max:.5f}, mean: {mean:.5f}, std: {std:.5f}",
        min=jnp.min(normed1), max=jnp.max(normed1),
        mean=jnp.mean(normed1), std=jnp.std(normed1))

    attn_output = self.attn1(normed1, attention_mask=attention_mask, rotary_emb=rotary_emb)
    jax.debug.print(
        "DEBUG: maxdiffusion block attn1. min: {min:.5f}, max: {max:.5f}, mean: {mean:.5f}, std: {std:.5f}",
        min=jnp.min(attn_output), max=jnp.max(attn_output),
        mean=jnp.mean(attn_output), std=jnp.std(attn_output))

    # Delegate to the saved original so behavior is unchanged.
    return orig_block_call(self, hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)

_BasicTransformerBlock1D.__call__ = patched_block_call

0 commit comments

Comments (0)