Skip to content

Commit dcef418

Browse files
committed
updated comments
1 parent 8c1cd6d commit dcef418

1 file changed

Lines changed: 2 additions & 6 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,7 @@ def wrap_flash_attention(query, key, value):
299299
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
300300

301301
# If attention_mask is provided, apply it to kv_segment_ids
302-
# attention_mask shape: (B, original_kv_seq_len) with 1 for real tokens, 0 for padded
303302
if attention_mask is not None:
304-
# Take the first item since padding pattern is same across batch (especially with CFG)
305-
# This keeps kv_segment_ids as (kv_padded_len,) for compatibility with vmapped_splash
306303
mask_len = min(key_seq_len, attention_mask.shape[1])
307304
kv_mask_for_batch = attention_mask[0, :mask_len] # (mask_len,)
308305
# If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
@@ -315,7 +312,6 @@ def wrap_flash_attention(query, key, value):
315312
kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) # (kv_padded_len,)
316313
else:
317314
kv_mask_padded = kv_mask_for_batch
318-
# Combine with existing kv_segment_ids (which handles block alignment padding)
319315
# Both are (kv_padded_len,) - element-wise multiplication
320316
kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
321317

@@ -1133,12 +1129,12 @@ def __call__(
11331129
encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :]
11341130
encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :]
11351131

1136-
# Use the passed encoder_attention_mask (created in embeddings_flax.py)
1132+
# Use the passed encoder_attention_mask (created in embeddings_flax.py) if using Flash Attention
11371133
# It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384
11381134
if encoder_attention_mask is not None:
11391135
encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len]
11401136
else:
1141-
# Fallback: no mask means treat all as valid
1137+
# Fallback: no mask means treat all as valid (for dot product attention)
11421138
encoder_attention_mask_img = None
11431139
else:
11441140
# If no image_seq_len is specified, treat all as text

0 commit comments

Comments (0)