@@ -96,19 +96,16 @@ def patched_replace(self, hidden_states, attention_mask):
 
 orig_block_call = _BasicTransformerBlock1D.__call__
 def patched_block_call(self, hidden_states, attention_mask=None, rotary_emb=None):
-    jax.debug.print("\n[MAXDIFFUSION W] to_q std: {k}, to_q bias: {b}",
-                    k=jnp.std(self.attn1.to_q.kernel), b=jnp.std(self.attn1.to_q.bias))
-    jax.debug.print("[MAXDIFFUSION W] to_k std: {k}, to_k bias: {b}",
-                    k=jnp.std(self.attn1.to_k.kernel), b=jnp.std(self.attn1.to_k.bias))
-    jax.debug.print("[MAXDIFFUSION W] to_v std: {k}, to_v bias: {b}",
-                    k=jnp.std(self.attn1.to_v.kernel), b=jnp.std(self.attn1.to_v.bias))
-    jax.debug.print("[MAXDIFFUSION W] to_out std: {k}, to_out bias: {b}",
-                    k=jnp.std(self.attn1.to_out.kernel), b=jnp.std(self.attn1.to_out.bias))
-    jax.debug.print("[MAXDIFFUSION W] norm_q std: {k}", k=jnp.std(self.attn1.norm_q.scale))
-
-    if attention_mask is not None:
-        jax.debug.print("[MAXDIFFUSION MASK] supplied to attention kernel sum: {}", jnp.sum(attention_mask))
-
+    normed1 = self.norm1(hidden_states)
+    jax.debug.print("DEBUG: maxdiffusion block norm1. min: {min:.5f}, max: {max:.5f}, mean: {mean:.5f}, std: {std:.5f}",
+                    min=jnp.min(normed1), max=jnp.max(normed1),
+                    mean=jnp.mean(normed1), std=jnp.std(normed1))
+
+    attn_output = self.attn1(normed1, attention_mask=attention_mask, rotary_emb=rotary_emb)
+    jax.debug.print("DEBUG: maxdiffusion block attn1. min: {min:.5f}, max: {max:.5f}, mean: {mean:.5f}, std: {std:.5f}",
+                    min=jnp.min(attn_output), max=jnp.max(attn_output),
+                    mean=jnp.mean(attn_output), std=jnp.std(attn_output))
+
     return orig_block_call(self, hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
 
 _BasicTransformerBlock1D.__call__ = patched_block_call
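
For reference, a minimal self-contained sketch of the monkey-patching pattern used above. ToyBlock is a hypothetical stand-in, not the maxdiffusion class; only jax and jnp are assumed. The key points it demonstrates are keeping a handle to the original __call__ so the patch can delegate to it, and using jax.debug.print so the statistics are emitted from inside jit-compiled code.

import jax
import jax.numpy as jnp

class ToyBlock:
    """Hypothetical stand-in for a transformer block."""
    def __call__(self, hidden_states):
        return hidden_states * 2.0

# Keep a handle to the original so the patched version can delegate to it
# (and so the patch can later be undone with ToyBlock.__call__ = orig_call).
orig_call = ToyBlock.__call__

def patched_call(self, hidden_states):
    out = orig_call(self, hidden_states)
    # jax.debug.print fires on every execution of jit-compiled code, whereas
    # a plain print would only run once, at trace time.
    jax.debug.print("DEBUG: toy block out. mean: {mean:.5f}, std: {std:.5f}",
                    mean=jnp.mean(out), std=jnp.std(out))
    return out

ToyBlock.__call__ = patched_call

block = ToyBlock()
y = jax.jit(lambda x: block(x))(jnp.ones((2, 4)))  # stats print on each run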