Skip to content

Commit df7e8dc

Browse files
committed
Change left/right padding handling in feature_extractor.py
1 parent 5a4bfac commit df7e8dc

1 file changed

Lines changed: 8 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/text_encoders/feature_extractor_ltx2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,15 @@ def _norm_and_concat_padded_batch(
4242
"""
4343
b, t, d, l = encoded_text.shape
4444

45-
# [B, T, 1, 1]
46-
mask = attention_mask[:, :, None, None]
47-
45+
# Build the validity mask for left-padded input (real tokens occupy the last `sequence_lengths` positions), identical to Diffusers `_pack_text_embeds`
46+
# Diffusers padding side is "left" for Gemma text encoders.
4847
sequence_lengths = jnp.sum(attention_mask, axis=-1)
48+
token_indices = jnp.arange(t)[None, :]
49+
start_indices = t - sequence_lengths[:, None]
50+
mask = token_indices >= start_indices
51+
52+
# Broadcast to [B, T, 1, 1]
53+
mask = mask[:, :, None, None]
4954

5055
eps = 1e-6
5156

0 commit comments

Comments
 (0)