@@ -301,10 +301,17 @@ def wrap_flash_attention(query, key, value):
   # If attention_mask is provided, apply it to kv_segment_ids
   # attention_mask shape: (B, original_kv_seq_len) with 1 for real tokens, 0 for padded
   if attention_mask is not None:
+    jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask shape: {}", attention_mask.shape)
+    jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask[0] sum: {}", attention_mask[0].sum())
+    if attention_mask.shape[0] > 1:
+      jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask[1] sum: {}", attention_mask[1].sum())
+    jax.debug.print("[DEBUG _tpu_flash_attention] key shape: {}, key_seq_len: {}, kv_padded_len: {}",
+                    key.shape, key_seq_len, kv_padded_len)
     # Take the first item since the padding pattern is the same across the batch (especially with CFG)
     # This keeps kv_segment_ids as (kv_padded_len,) for compatibility with vmapped_splash
     mask_len = min(key_seq_len, attention_mask.shape[1])
     kv_mask_for_batch = attention_mask[0, :mask_len]  # (mask_len,)
+    jax.debug.print("[DEBUG _tpu_flash_attention] Using attention_mask[0], sum: {}", kv_mask_for_batch.sum())
     # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
     if key_seq_len > mask_len:
       extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
@@ -318,6 +325,7 @@ def wrap_flash_attention(query, key, value):
     # Combine with existing kv_segment_ids (which handles block alignment padding)
     # Both are (kv_padded_len,) - element-wise multiplication
     kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+    jax.debug.print("[DEBUG _tpu_flash_attention] Final kv_segment_ids sum: {}", kv_segment_ids.sum())
 
   segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
 
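For reference, a minimal standalone sketch of the masking logic these debug prints inspect. The helper name, shapes, and base-segment-id construction are assumptions for illustration, not the repo's actual code; only the mask-truncate/pad-and-multiply flow mirrors the diff above:

```python
import jax
import jax.numpy as jnp

def build_kv_segment_ids(attention_mask, key_seq_len, kv_padded_len):
  """Fold a (B, kv_len) padding mask into a (kv_padded_len,) kv segment-id vector."""
  # Row 0 only: the padding pattern is assumed identical across the batch (e.g. with CFG).
  mask_len = min(key_seq_len, attention_mask.shape[1])
  kv_mask = attention_mask[0, :mask_len].astype(jnp.int32)
  # Tokens beyond the provided mask are treated as valid (padded with 1s).
  if key_seq_len > mask_len:
    kv_mask = jnp.concatenate([kv_mask, jnp.ones((key_seq_len - mask_len,), jnp.int32)])
  # Block-alignment padding out to kv_padded_len is never valid (padded with 0s).
  kv_mask_padded = jnp.concatenate([kv_mask, jnp.zeros((kv_padded_len - key_seq_len,), jnp.int32)])
  # Hypothetical base segment ids: 1 for the real kv region, 0 for alignment padding.
  base_ids = (jnp.arange(kv_padded_len) < key_seq_len).astype(jnp.int32)
  # Element-wise product, as in `kv_segment_ids * kv_mask_padded` above.
  return base_ids * kv_mask_padded

mask = jnp.array([[1, 1, 1, 0]], jnp.int32)  # one padded kv token
ids = build_kv_segment_ids(mask, key_seq_len=4, kv_padded_len=8)
jax.debug.print("kv_segment_ids: {}", ids)  # [1 1 1 0 0 0 0 0]
```

The sum of the result equals the number of kv tokens that remain attendable, which is what the `Final kv_segment_ids sum` debug print reports.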