We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 5a4bfac commit df7e8dc Copy full SHA for df7e8dc
1 file changed
src/maxdiffusion/models/ltx2/text_encoders/feature_extractor_ltx2.py
@@ -42,10 +42,15 @@ def _norm_and_concat_padded_batch(
42
"""
43
b, t, d, l = encoded_text.shape
44
45
- # [B, T, 1, 1]
46
- mask = attention_mask[:, :, None, None]
47
-
+ # Build the token-validity mask for left-padded input, identical to Diffusers `_pack_text_embeds`.
+ # Diffusers padding side is "left" for Gemma text encoders, so the real tokens occupy the last `sequence_lengths` positions of each row.
48
sequence_lengths = jnp.sum(attention_mask, axis=-1)
+ token_indices = jnp.arange(t)[None, :]
49
+ start_indices = t - sequence_lengths[:, None]
50
+ mask = token_indices >= start_indices
51
+
52
+ # Broadcast to [B, T, 1, 1]
53
+ mask = mask[:, :, None, None]
54
55
eps = 1e-6
56
0 commit comments