@@ -299,10 +299,7 @@ def wrap_flash_attention(query, key, value):
299299 kv_segment_ids = (kv_indices < key_seq_len ).astype (jnp .int32 )
300300
301301 # If attention_mask is provided, apply it to kv_segment_ids
302- # attention_mask shape: (B, original_kv_seq_len) with 1 for real tokens, 0 for padded
303302 if attention_mask is not None :
304- # Take the first item since padding pattern is same across batch (especially with CFG)
305- # This keeps kv_segment_ids as (kv_padded_len,) for compatibility with vmapped_splash
306303 mask_len = min (key_seq_len , attention_mask .shape [1 ])
307304 kv_mask_for_batch = attention_mask [0 , :mask_len ] # (mask_len,)
308305 # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
@@ -315,7 +312,6 @@ def wrap_flash_attention(query, key, value):
315312 kv_mask_padded = jnp .concatenate ([kv_mask_for_batch , padding ], axis = 0 ) # (kv_padded_len,)
316313 else :
317314 kv_mask_padded = kv_mask_for_batch
318- # Combine with existing kv_segment_ids (which handles block alignment padding)
319315 # Both are (kv_padded_len,) - element-wise multiplication
320316 kv_segment_ids = (kv_segment_ids * kv_mask_padded ).astype (jnp .int32 )
321317
@@ -1133,12 +1129,12 @@ def __call__(
11331129 encoder_hidden_states_img = encoder_hidden_states [:, :padded_img_len , :]
11341130 encoder_hidden_states_text = encoder_hidden_states [:, padded_img_len :, :]
11351131
1136- # Use the passed encoder_attention_mask (created in embeddings_flax.py)
1132+ # Use the passed encoder_attention_mask (created in embeddings_flax.py) if using Flash Attention
11371133 # It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384
11381134 if encoder_attention_mask is not None :
11391135 encoder_attention_mask_img = encoder_attention_mask [:, :padded_img_len ]
11401136 else :
1141- # Fallback: no mask means treat all as valid
1137+ # Fallback: no mask means treat all as valid (for dot product attention)
11421138 encoder_attention_mask_img = None
11431139 else :
11441140 # If no image_seq_len is specified, treat all as text
0 commit comments