@@ -301,30 +301,37 @@ def wrap_flash_attention(query, key, value):
301301 # If attention_mask is provided, apply it to kv_segment_ids
302302 # attention_mask shape: (B, original_kv_seq_len) with 1 for real tokens, 0 for padded
303303 if attention_mask is not None :
304+ # DEBUG: Check shapes and values
304305 jax .debug .print ("[DEBUG _tpu_flash_attention] attention_mask shape: {}" , attention_mask .shape )
305306 jax .debug .print ("[DEBUG _tpu_flash_attention] attention_mask[0] sum: {}" , attention_mask [0 ].sum ())
306307 if attention_mask .shape [0 ] > 1 :
307308 jax .debug .print ("[DEBUG _tpu_flash_attention] attention_mask[1] sum: {}" , attention_mask [1 ].sum ())
308309 jax .debug .print ("[DEBUG _tpu_flash_attention] key shape: {}, key_seq_len: {}, kv_padded_len: {}" ,
309310 key .shape , key_seq_len , kv_padded_len )
310- # Take the first item since padding pattern is same across batch (especially with CFG)
311- # This keeps kv_segment_ids as (kv_padded_len,) for compatibility with vmapped_splash
311+
312+ # For CFG, different batch items have different padding patterns (pos vs neg prompts)
313+ # We need a per-batch mask, but segment_ids need to be 1D for vmapped_splash
314+ # Solution: Use logical OR across batch - a position is valid if ANY batch item needs it
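315+ # NOTE(review): this assumes positions that are padding for some batch items contribute nothing to their output; a zero key still receives softmax weight exp(0) = 1, so confirm this holds or switch to per-batch segment ids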
312316 mask_len = min (key_seq_len , attention_mask .shape [1 ])
313- kv_mask_for_batch = attention_mask [0 , :mask_len ] # (mask_len,)
314- jax .debug .print ("[DEBUG _tpu_flash_attention] Using attention_mask[0], sum: {}" , kv_mask_for_batch .sum ())
317+ kv_mask_any = jnp .max (attention_mask [:, :mask_len ], axis = 0 ).astype (jnp .int32 ) # (mask_len,)
318+
319+ jax .debug .print ("[DEBUG _tpu_flash_attention] Using OR across batch, sum: {}" , kv_mask_any .sum ())
320+
315321 # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
316322 if key_seq_len > mask_len :
317323 extra_valid = jnp .ones ((key_seq_len - mask_len ,), dtype = jnp .int32 )
318- kv_mask_for_batch = jnp .concatenate ([kv_mask_for_batch , extra_valid ], axis = 0 ) # (key_seq_len,)
324+ kv_mask_any = jnp .concatenate ([kv_mask_any , extra_valid ], axis = 0 ) # (key_seq_len,)
319325 # Pad to kv_padded_len
320326 if kv_padded_len > key_seq_len :
321327 padding = jnp .zeros ((kv_padded_len - key_seq_len ,), dtype = jnp .int32 )
322- kv_mask_padded = jnp .concatenate ([kv_mask_for_batch , padding ], axis = 0 ) # (kv_padded_len,)
328+ kv_mask_padded = jnp .concatenate ([kv_mask_any , padding ], axis = 0 ) # (kv_padded_len,)
323329 else :
324- kv_mask_padded = kv_mask_for_batch
330+ kv_mask_padded = kv_mask_any
325331 # Combine with existing kv_segment_ids (which handles block alignment padding)
326332 # Both are (kv_padded_len,) - element-wise multiplication
327333 kv_segment_ids = (kv_segment_ids * kv_mask_padded ).astype (jnp .int32 )
334+
328335 jax .debug .print ("[DEBUG _tpu_flash_attention] Final kv_segment_ids sum: {}" , kv_segment_ids .sum ())
329336
330337 segment_ids = splash_attention_kernel .SegmentIds (q = q_segment_ids , kv = kv_segment_ids )
0 commit comments