@@ -298,10 +298,7 @@ def wrap_flash_attention(query, key, value):
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)

     # If attention_mask is provided, apply it to kv_segment_ids
-    # attention_mask shape: (B, original_kv_seq_len) with 1 for real tokens, 0 for padded
     if attention_mask is not None:
-        # Take the first item since padding pattern is same across batch (especially with CFG)
-        # This keeps kv_segment_ids as (kv_padded_len,) for compatibility with vmapped_splash
         mask_len = min(key_seq_len, attention_mask.shape[1])
         kv_mask_for_batch = attention_mask[0, :mask_len]  # (mask_len,)
         # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
@@ -314,7 +311,6 @@ def wrap_flash_attention(query, key, value):
             kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)  # (kv_padded_len,)
         else:
             kv_mask_padded = kv_mask_for_batch
-        # Combine with existing kv_segment_ids (which handles block alignment padding)
         # Both are (kv_padded_len,) - element-wise multiplication
         kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)

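For reference, a minimal, self-contained sketch of the masking step these two hunks cover; this is not the PR's exact code, and the sequence lengths, batch layout, and intermediate variable names are illustrative assumptions:

```python
import jax.numpy as jnp

key_seq_len = 257      # number of real KV tokens (illustrative)
kv_padded_len = 384    # KV length after block alignment (illustrative)

# Block-alignment mask: 1 inside the real sequence, 0 in the alignment padding.
kv_indices = jnp.arange(kv_padded_len)
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)

# Caller-provided mask: (B, original_kv_seq_len), 1 = real token, 0 = padded.
# Here the last 16 tokens are padded, purely for illustration.
attention_mask = jnp.concatenate(
    [jnp.ones((2, 241), jnp.int32), jnp.zeros((2, 16), jnp.int32)], axis=1)

# Take batch item 0 (padding pattern assumed identical across the batch),
# pad with 1s out to kv_padded_len, then fold into kv_segment_ids element-wise.
mask_len = min(key_seq_len, attention_mask.shape[1])
kv_mask = attention_mask[0, :mask_len]
kv_mask = jnp.concatenate(
    [kv_mask, jnp.ones((kv_padded_len - mask_len,), jnp.int32)], axis=0)
kv_segment_ids = (kv_segment_ids * kv_mask).astype(jnp.int32)  # (kv_padded_len,)
```

Because the block-alignment mask is already 0 beyond key_seq_len, padding the caller's mask with 1s does not mark any alignment padding as valid.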
@@ -1101,12 +1097,12 @@ def __call__(
         encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :]
         encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :]

-        # Use the passed encoder_attention_mask (created in embeddings_flax.py)
+        # Use the passed encoder_attention_mask (created in embeddings_flax.py) if using Flash Attention
         # It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384
         if encoder_attention_mask is not None:
             encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len]
         else:
-            # Fallback: no mask means treat all as valid
+            # Fallback: no mask means treat all as valid (for dot product attention)
             encoder_attention_mask_img = None
     else:
         # If no image_seq_len is specified, treat all as text
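A small sketch of the image/text split and mask slicing this hunk touches; the tensor shapes, the all-ones text portion of the mask, and any names beyond those in the diff are assumptions for illustration:

```python
import jax.numpy as jnp

batch, dim = 2, 64
padded_img_len, real_img_len, text_len = 384, 257, 226  # illustrative sizes

# Joint encoder sequence: padded image tokens followed by text tokens.
encoder_hidden_states = jnp.zeros((batch, padded_img_len + text_len, dim))

# Mask built upstream (e.g. in embeddings_flax.py): [1]*257 + [0]*127 for the
# image part; the text part is assumed to be all 1s here.
encoder_attention_mask = jnp.concatenate(
    [jnp.ones((batch, real_img_len), jnp.int32),
     jnp.zeros((batch, padded_img_len - real_img_len), jnp.int32),
     jnp.ones((batch, text_len), jnp.int32)], axis=1)

# Split the hidden states and slice the image portion of the mask, as in the diff.
encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :]
encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :]
encoder_attention_mask_img = (
    encoder_attention_mask[:, :padded_img_len]
    if encoder_attention_mask is not None else None)
```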