@@ -132,9 +132,10 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
132132
133133 num_duplications = t // self .num_learnable_registers
134134 registers = jnp .tile (self .learnable_registers [...], (num_duplications , 1 ))
135- registers = jnp .expand_dims (registers , 0 )
136-
137- if attention_mask .ndim == 2 :
135+
136+ if attention_mask .ndim == 4 :
137+ mask = attention_mask .squeeze (1 ).squeeze (1 )
138+ elif attention_mask .ndim == 2 :
138139 mask = attention_mask
139140 else :
140141 mask = attention_mask .squeeze (- 1 ) # [B, T]
@@ -154,16 +155,15 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
154155 shifted_hidden_states = jnp .zeros_like (hidden_states )
155156 shifted_hidden_states = shifted_hidden_states .at [b_idx , target_indices , :].set (hidden_states )
156157
157- # Shift mask
158- shifted_mask = jnp .zeros_like (curr_mask )
159- shifted_mask = shifted_mask .at [b_idx , target_indices ].set (curr_mask )
160-
161158 # 2. Add Learnable Registers
162- # Where shifted_mask is 1, keep valid tokens. Where 0, insert registers.
163- output = jnp .where (shifted_mask [..., None ] == 1 , shifted_hidden_states , registers )
159+ # Where flipped_mask is 1, keep valid tokens. Where 0, insert registers.
160+ flipped_mask = jnp .flip (curr_mask , axis = [1 ])
161+ flipped_mask_expanded = flipped_mask [..., None ]
162+
163+ output = jnp .where (flipped_mask_expanded == 1 , shifted_hidden_states , registers )
164164
165- # Overwrite attention_mask with all-ones since padding is now filled with registers.
166- new_mask = jnp .ones_like (attention_mask )
165+ # Overwrite attention_mask with all-zeros since padding is now filled with registers.
166+ new_mask = jnp .zeros_like (attention_mask )
167167 return output , new_mask
168168
169169 def _compute_1d_rope (self , seq_len : int , dtype : DType ) -> Tuple [Array , Array ]:
0 commit comments