@@ -182,18 +182,20 @@ def __call__(
182182 ) -> Tuple [Array , Array ]:
183183
184184 # Debug print 1: Start
185- print(f"\nDEBUG: Embeddings1DConnector Start. hidden_states shape: {hidden_states.shape}")
186- _t_np = jax .device_get (hidden_states )
187- print (f" min: { _t_np .min ():.5f} , max: { _t_np .max ():.5f} , mean: { _t_np .mean ():.5f} , std: { _t_np .std ():.5f} " )
185+ jax.debug.print("\nDEBUG: Embeddings1DConnector Start. hidden_states shape: {}", hidden_states.shape)
186+ jax .debug .print (" min: {min:.5f}, max: {max:.5f}, mean: {mean:.5f}, std: {std:.5f}" ,
187+ min = jnp .min (hidden_states ), max = jnp .max (hidden_states ),
188+ mean = jnp .mean (hidden_states ), std = jnp .std (hidden_states ))
188189
189190 # 1. Thinking Tokens
190191 if self .num_learnable_registers > 0 and attention_mask is not None :
191192 hidden_states , attention_mask = self ._replace_padded_with_learnable_registers (hidden_states , attention_mask )
192193
193194 # Debug print 2: After Padding Replacement
194- print (f"DEBUG: After replacing padded with registers. hidden_states shape: { hidden_states .shape } " )
195- _t_np = jax .device_get (hidden_states )
196- print (f" min: { _t_np .min ():.5f} , max: { _t_np .max ():.5f} , mean: { _t_np .mean ():.5f} , std: { _t_np .std ():.5f} " )
195+ jax .debug .print ("DEBUG: After replacing padded with registers." )
196+ jax .debug .print (" min: {min:.5f}, max: {max:.5f}, mean: {mean:.5f}, std: {std:.5f}" ,
197+ min = jnp .min (hidden_states ), max = jnp .max (hidden_states ),
198+ mean = jnp .mean (hidden_states ), std = jnp .std (hidden_states ))
197199
198200 # 2. RoPE
199201 seq_len = hidden_states .shape [1 ]
@@ -217,16 +219,18 @@ def block_scan_fn(carry, block_module):
217219 )(hidden_states , self .stacked_blocks )
218220
219221 # Debug print 3: After scan
220- print (f"DEBUG: After transformer blocks scan. hidden_states shape: { hidden_states .shape } " )
221- _t_np = jax .device_get (hidden_states )
222- print (f" min: { _t_np .min ():.5f} , max: { _t_np .max ():.5f} , mean: { _t_np .mean ():.5f} , std: { _t_np .std ():.5f} " )
222+ jax .debug .print ("DEBUG: After transformer blocks scan." )
223+ jax .debug .print (" min: {min:.5f}, max: {max:.5f}, mean: {mean:.5f}, std: {std:.5f}" ,
224+ min = jnp .min (hidden_states ), max = jnp .max (hidden_states ),
225+ mean = jnp .mean (hidden_states ), std = jnp .std (hidden_states ))
223226
224227 # 4. Final Norm
225228 hidden_states = self .final_norm (hidden_states )
226229
227230 # Debug print 4: Final Norm
228- print (f"DEBUG: After final norm. hidden_states shape: { hidden_states .shape } " )
229- _t_np = jax .device_get (hidden_states )
230- print (f" min: { _t_np .min ():.5f} , max: { _t_np .max ():.5f} , mean: { _t_np .mean ():.5f} , std: { _t_np .std ():.5f} " )
231+ jax .debug .print ("DEBUG: After final norm." )
232+ jax .debug .print (" min: {min:.5f}, max: {max:.5f}, mean: {mean:.5f}, std: {std:.5f}" ,
233+ min = jnp .min (hidden_states ), max = jnp .max (hidden_states ),
234+ mean = jnp .mean (hidden_states ), std = jnp .std (hidden_states ))
231235
232236 return hidden_states , attention_mask