@@ -81,18 +81,13 @@ def patched_replace(self, hidden_states, attention_mask):
     mask = attention_mask.squeeze(-1)  # [B, T]
     curr_mask = (mask > 0.5).astype(jnp.int32)
 
-    jax.debug.print("\n[MAXDIFFUSION] Mask Debug:")
-    jax.debug.print(" Input Attn Mask min/max: {} / {}", jnp.min(attention_mask), jnp.max(attention_mask))
-    jax.debug.print(" Curr Mask sum: {} (valid tokens)", jnp.sum(curr_mask))
-    jax.debug.print(" Curr Mask start 20 elements: {}", curr_mask[0, :20])
+    jax.debug.print("\n[MAXDIFFUSION MASK DEBUG]")
+    jax.debug.print(" Input Attn Mask sum: {}", jnp.sum(attention_mask))
+    jax.debug.print(" Curr Mask sum: {} (start elements: {})", jnp.sum(curr_mask), curr_mask[0, :10])
 
     flipped = jnp.flip(curr_mask, axis=[1])
-    jax.debug.print(" Flipped Mask Fwd logic sum: {} (first 20 elements: {})", jnp.sum(flipped), flipped[0, :20])
+    jax.debug.print(" Flipped Mask (Diffusers Logic) sum: {} (start elements: {})", jnp.sum(flipped), flipped[0, :10])
 
-    regs = self.learnable_registers.value
-    jax.debug.print(" [MAXDIFFUSION] Connector Registers std: {std}, mean: {mean}, min: {min}",
-                    std=jnp.std(regs), mean=jnp.mean(regs), min=jnp.min(regs))
-
     return orig_replace(self, hidden_states, attention_mask)
 
 Embeddings1DConnector._replace_padded_with_learnable_registers = patched_replace
@@ -101,32 +96,20 @@ def patched_replace(self, hidden_states, attention_mask):
 
 orig_block_call = _BasicTransformerBlock1D.__call__
 def patched_block_call(self, hidden_states, attention_mask=None, rotary_emb=None):
-    jax.debug.print("[MAXDIFFUSION] Block Input std: {std}, min: {min}, max: {max}",
-                    std=jnp.std(hidden_states), min=jnp.min(hidden_states), max=jnp.max(hidden_states))
+    jax.debug.print("\n[MAXDIFFUSION W] to_q std: {k}, to_q bias: {b}",
+                    k=jnp.std(self.attn1.to_q.kernel), b=jnp.std(self.attn1.to_q.bias))
+    jax.debug.print("[MAXDIFFUSION W] to_k std: {k}, to_k bias: {b}",
+                    k=jnp.std(self.attn1.to_k.kernel), b=jnp.std(self.attn1.to_k.bias))
+    jax.debug.print("[MAXDIFFUSION W] to_v std: {k}, to_v bias: {b}",
+                    k=jnp.std(self.attn1.to_v.kernel), b=jnp.std(self.attn1.to_v.bias))
+    jax.debug.print("[MAXDIFFUSION W] to_out std: {k}, to_out bias: {b}",
+                    k=jnp.std(self.attn1.to_out.kernel), b=jnp.std(self.attn1.to_out.bias))
+    jax.debug.print("[MAXDIFFUSION W] norm_q std: {k}", k=jnp.std(self.attn1.norm_q.scale))
 
-    # 1. Norm -> Attention
-    normed1 = self.norm1(hidden_states)
-    jax.debug.print(" [MAXDIFFUSION] norm1 std: {std}, min: {min}, max: {max}",
-                    std=jnp.std(normed1), min=jnp.min(normed1), max=jnp.max(normed1))
+    if attention_mask is not None:
+        jax.debug.print("[MAXDIFFUSION MASK] supplied to attention kernel sum: {}", jnp.sum(attention_mask))
 
-    attn_output = self.attn1(normed1, attention_mask=attention_mask, rotary_emb=rotary_emb)
-    jax.debug.print(" [MAXDIFFUSION] attn1 std: {std}, min: {min}, max: {max}",
-                    std=jnp.std(attn_output), min=jnp.min(attn_output), max=jnp.max(attn_output))
-
-    hidden_states = hidden_states + attn_output
-
-    # 2. Norm -> FeedForward
-    normed2 = self.norm2(hidden_states)
-    jax.debug.print(" [MAXDIFFUSION] norm2 std: {std}, min: {min}, max: {max}",
-                    std=jnp.std(normed2), min=jnp.min(normed2), max=jnp.max(normed2))
-
-    ff_output = self.ff(normed2)
-    jax.debug.print(" [MAXDIFFUSION] ff std: {std}, min: {min}, max: {max}",
-                    std=jnp.std(ff_output), min=jnp.min(ff_output), max=jnp.max(ff_output))
-
-    hidden_states = hidden_states + ff_output
-
-    return hidden_states
+    return orig_block_call(self, hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
 
 _BasicTransformerBlock1D.__call__ = patched_block_call
 
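Both hunks rely on the same wrap-and-delegate monkey patch: keep a module-level handle to the original method, emit diagnostics with jax.debug.print (which, unlike plain print(), still fires on every call inside jit-traced code), then fall through to the original. Below is a minimal, self-contained sketch of that pattern; the Toy class is a hypothetical stand-in for _BasicTransformerBlock1D, not maxdiffusion API.

import jax
import jax.numpy as jnp

class Toy:  # hypothetical stand-in for the patched module
    def __call__(self, x):
        return x * 2.0

orig_call = Toy.__call__  # saved handle lets the patch delegate

def patched_call(self, x):
    # jax.debug.print runs at execution time on every traced call;
    # a plain print() here would fire only once, during tracing.
    jax.debug.print("[DEBUG] input std: {s}", s=jnp.std(x))
    return orig_call(self, x)

Toy.__call__ = patched_call
jax.jit(lambda x: Toy()(x))(jnp.arange(4.0))  # prints the input std
Toy.__call__ = orig_call  # restore once debugging is done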