@@ -133,9 +133,7 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
         num_duplications = t // self.num_learnable_registers
         registers = jnp.tile(self.learnable_registers[...], (num_duplications, 1))
 
-        if attention_mask.ndim == 4:
-            mask = attention_mask.squeeze(1).squeeze(1)
-        elif attention_mask.ndim == 2:
+        if attention_mask.ndim == 2:
             mask = attention_mask
         else:
             mask = attention_mask.squeeze(-1)  # [B, T]
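The tiling at the top of this hunk only yields a register bank of exactly [T, D] when T is a multiple of num_learnable_registers; a toy sketch of that shape logic (the concrete sizes here are illustrative, not taken from the model config):

import jax.numpy as jnp

t, num_registers, d = 8, 4, 16
learnable_registers = jnp.ones((num_registers, d))  # stand-in for the learned parameter

assert t % num_registers == 0                       # the integer division below assumes this
num_duplications = t // num_registers
registers = jnp.tile(learnable_registers, (num_duplications, 1))
assert registers.shape == (t, d)                    # one register row per sequence position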
@@ -155,15 +153,16 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
         shifted_hidden_states = jnp.zeros_like(hidden_states)
         shifted_hidden_states = shifted_hidden_states.at[b_idx, target_indices, :].set(hidden_states)
 
+        # Shift mask
+        shifted_mask = jnp.zeros_like(curr_mask)
+        shifted_mask = shifted_mask.at[b_idx, target_indices].set(curr_mask)
+
         # 2. Add Learnable Registers
-        # Where flipped_mask is 1, keep valid tokens. Where 0, insert registers.
-        flipped_mask = jnp.flip(curr_mask, axis=[1])
-        flipped_mask_expanded = flipped_mask[..., None]
-
-        output = jnp.where(flipped_mask_expanded == 1, shifted_hidden_states, registers)
+        # Where shifted_mask is 1, keep valid tokens. Where 0, insert registers.
+        output = jnp.where(shifted_mask[..., None] == 1, shifted_hidden_states, registers)
 
-        # Overwrite attention_mask with all-zeros since padding is now filled with registers.
-        new_mask = jnp.zeros_like(attention_mask)
+        # Overwrite attention_mask with all-ones since padding is now filled with registers.
+        new_mask = jnp.ones_like(attention_mask)
         return output, new_mask
 
     def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
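A toy walk-through of the new shift-and-fill logic on a single sequence. It assumes target_indices right-aligns the valid tokens; reusing the same scatter for the mask keeps it aligned with the shifted tokens by construction, which the flipped mask it replaces did not guarantee. The index computation and all shapes below are illustrative, not copied from the source:

import jax.numpy as jnp

mask = jnp.array([[1, 1, 0, 0]])               # [B=1, T=4], 2 valid tokens
tokens = jnp.arange(4.0).reshape(1, 4, 1)      # stand-in hidden states
registers = jnp.full((1, 4, 1), -1.0)          # stand-in learnable registers

b_idx = jnp.zeros((1, 4), dtype=jnp.int32)
num_pad = 4 - mask.sum(axis=1, keepdims=True)  # 2 padded positions
target_indices = (jnp.arange(4) + num_pad) % 4 # right-align valid tokens

shifted = jnp.zeros_like(tokens).at[b_idx, target_indices, :].set(tokens)
shifted_mask = jnp.zeros_like(mask).at[b_idx, target_indices].set(mask)
out = jnp.where(shifted_mask[..., None] == 1, shifted, registers)
# out[0, :, 0] -> [-1., -1., 0., 1.]: registers fill the front, valid
# tokens sit at the end; the wrapped padding values are overwritten.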
@@ -181,9 +180,20 @@ def __call__(
         hidden_states: Array,
         attention_mask: Optional[Array] = None,
     ) -> Tuple[Array, Array]:
+
+        # Debug print 1: Start
+        print(f"\nDEBUG: Embeddings1DConnector Start. hidden_states shape: {hidden_states.shape}")
+        _t_np = jax.device_get(hidden_states)
+        print(f" min: {_t_np.min():.5f}, max: {_t_np.max():.5f}, mean: {_t_np.mean():.5f}, std: {_t_np.std():.5f}")
+
         # 1. Thinking Tokens
         if self.num_learnable_registers > 0 and attention_mask is not None:
             hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
+
+        # Debug print 2: After Padding Replacement
+        print(f"DEBUG: After replacing padded with registers. hidden_states shape: {hidden_states.shape}")
+        _t_np = jax.device_get(hidden_states)
+        print(f" min: {_t_np.min():.5f}, max: {_t_np.max():.5f}, mean: {_t_np.mean():.5f}, std: {_t_np.std():.5f}")
 
         # 2. RoPE
         seq_len = hidden_states.shape[1]
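This debug block pattern repeats four times in __call__; a hypothetical helper (not part of this change) that would factor it out:

import jax

def _debug_stats(label, x):
    x_np = jax.device_get(x)  # pull the array to host memory
    print(f"DEBUG: {label}. shape: {x.shape}")
    print(f" min: {x_np.min():.5f}, max: {x_np.max():.5f}, mean: {x_np.mean():.5f}, std: {x_np.std():.5f}")

Each block would then reduce to a single call such as _debug_stats("After final norm", hidden_states).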
@@ -205,8 +215,18 @@ def block_scan_fn(carry, block_module):
             in_axes=(nnx.Carry, 0),  # Scan over the layers dimension (0) of block_module
             out_axes=(nnx.Carry, 0),
         )(hidden_states, self.stacked_blocks)
+
+        # Debug print 3: After scan
+        print(f"DEBUG: After transformer blocks scan. hidden_states shape: {hidden_states.shape}")
+        _t_np = jax.device_get(hidden_states)
+        print(f" min: {_t_np.min():.5f}, max: {_t_np.max():.5f}, mean: {_t_np.mean():.5f}, std: {_t_np.std():.5f}")
 
         # 4. Final Norm
         hidden_states = self.final_norm(hidden_states)
 
+        # Debug print 4: Final Norm
+        print(f"DEBUG: After final norm. hidden_states shape: {hidden_states.shape}")
+        _t_np = jax.device_get(hidden_states)
+        print(f" min: {_t_np.min():.5f}, max: {_t_np.max():.5f}, mean: {_t_np.mean():.5f}, std: {_t_np.std():.5f}")
+
         return hidden_states, attention_mask
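One caveat with these prints: they only work while the connector runs eagerly, since jax.device_get cannot materialize tracers once __call__ is wrapped in jax.jit. A jit-compatible sketch using jax.debug.print (the helper name is hypothetical):

import jax

def _debug_stats_jit(label, x):
    # jax.debug.print stages the print as a host callback, so it also
    # fires inside jit-compiled code.
    jax.debug.print(
        label + ": min={mn}, max={mx}, mean={me}, std={sd}",
        mn=x.min(), mx=x.max(), me=x.mean(), sd=x.std(),
    )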