@@ -103,7 +103,30 @@ def patched_replace(self, hidden_states, attention_mask):
def patched_block_call(self, hidden_states, attention_mask=None, rotary_emb=None):
    """Instrumented stand-in for ``_BasicTransformerBlock1D.__call__``.

    Runs the block's sub-layers explicitly — norm1 -> attn1 -> residual add,
    then norm2 -> ff -> residual add — emitting std/min/max statistics after
    each stage via ``jax.debug.print`` so numerical blow-ups can be localised
    to a specific sub-layer.

    NOTE(review): assumes the real block is exactly norm1/attn1/norm2/ff with
    plain residual additions and no extra conditioning inputs — confirm
    against the original ``_BasicTransformerBlock1D.__call__``.

    Args:
        hidden_states: input activations for the block.
        attention_mask: optional mask forwarded to ``self.attn1``.
        rotary_emb: optional rotary embeddings forwarded to ``self.attn1``.

    Returns:
        The block output after both residual branches.
    """
    def _stats(t):
        # Keyword bundle for the debug format strings; jax.debug.print
        # stays traceable under jit, unlike plain print().
        return {"std": jnp.std(t), "min": jnp.min(t), "max": jnp.max(t)}

    jax.debug.print("[MAXDIFFUSION] Block Input std: {std}, min: {min}, max: {max}",
                    **_stats(hidden_states))

    # 1. Norm -> Attention, with residual connection.
    normed1 = self.norm1(hidden_states)
    jax.debug.print(" [MAXDIFFUSION] norm1 std: {std}, min: {min}, max: {max}",
                    **_stats(normed1))

    attn_output = self.attn1(normed1, attention_mask=attention_mask, rotary_emb=rotary_emb)
    jax.debug.print(" [MAXDIFFUSION] attn1 std: {std}, min: {min}, max: {max}",
                    **_stats(attn_output))

    hidden_states = hidden_states + attn_output

    # 2. Norm -> FeedForward, with residual connection.
    normed2 = self.norm2(hidden_states)
    jax.debug.print(" [MAXDIFFUSION] norm2 std: {std}, min: {min}, max: {max}",
                    **_stats(normed2))

    ff_output = self.ff(normed2)
    jax.debug.print(" [MAXDIFFUSION] ff std: {std}, min: {min}, max: {max}",
                    **_stats(ff_output))

    return hidden_states + ff_output
107130
# Monkey-patch: route every _BasicTransformerBlock1D forward pass through the
# instrumented wrapper so per-sub-layer statistics are printed during tracing.
_BasicTransformerBlock1D.__call__ = patched_block_call
109132
0 commit comments