Skip to content

Commit e5c6324

Browse files
committed
updated comments
1 parent ef6ead2 commit e5c6324

1 file changed

Lines changed: 2 additions & 6 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,7 @@ def wrap_flash_attention(query, key, value):
298298
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
299299

300300
# If attention_mask is provided, apply it to kv_segment_ids
301-
# attention_mask shape: (B, original_kv_seq_len) with 1 for real tokens, 0 for padded
302301
if attention_mask is not None:
303-
# Take the first item since padding pattern is same across batch (especially with CFG)
304-
# This keeps kv_segment_ids as (kv_padded_len,) for compatibility with vmapped_splash
305302
mask_len = min(key_seq_len, attention_mask.shape[1])
306303
kv_mask_for_batch = attention_mask[0, :mask_len] # (mask_len,)
307304
# If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
@@ -314,7 +311,6 @@ def wrap_flash_attention(query, key, value):
314311
kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) # (kv_padded_len,)
315312
else:
316313
kv_mask_padded = kv_mask_for_batch
317-
# Combine with existing kv_segment_ids (which handles block alignment padding)
318314
# Both are (kv_padded_len,) - element-wise multiplication
319315
kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
320316

@@ -1101,12 +1097,12 @@ def __call__(
11011097
encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :]
11021098
encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :]
11031099

1104-
# Use the passed encoder_attention_mask (created in embeddings_flax.py)
1100+
# Use the passed encoder_attention_mask (created in embeddings_flax.py) if using Flash Attention
11051101
# It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384
11061102
if encoder_attention_mask is not None:
11071103
encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len]
11081104
else:
1109-
# Fallback: no mask means treat all as valid
1105+
# Fallback: no mask means treat all as valid (for dot product attention)
11101106
encoder_attention_mask_img = None
11111107
else:
11121108
# If no image_seq_len is specified, treat all as text

0 commit comments

Comments (0)