Skip to content

Commit 38fd3c7

Browse files
committed
fix embeddings masking
1 parent 0274615 commit 38fd3c7

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,9 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
168168
# Where shifted_mask is 1, keep valid tokens. Where 0, insert registers.
169169
output = jnp.where(shifted_mask[..., None] == 1, shifted_hidden_states, registers)
170170

171-
# Overwrite attention_mask with all-ones since padding is now filled with registers.
172-
new_mask = jnp.ones_like(attention_mask)
171+
# Padding has been filled with valid register tokens. The entire sequence
172+
# must now be attended to, so we clear the mask.
173+
new_mask = None
173174
return output, new_mask
174175

175176
def _compute_1d_rope(self, batch_size: int, seq_len: int, dtype: DType) -> Tuple[Array, Array]:

0 commit comments

Comments
 (0)