@@ -301,13 +301,6 @@ def wrap_flash_attention(query, key, value):
     # If attention_mask is provided, apply it to kv_segment_ids
     # attention_mask shape: (B, original_kv_seq_len) with 1 for real tokens, 0 for padded
     if attention_mask is not None:
-        # DEBUG: Check shapes and values
-        jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask shape: {}", attention_mask.shape)
-        jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask[0] sum: {}", attention_mask[0].sum())
-        if attention_mask.shape[0] > 1:
-            jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask[1] sum: {}", attention_mask[1].sum())
-        jax.debug.print("[DEBUG _tpu_flash_attention] key shape: {}, key_seq_len: {}, kv_padded_len: {}",
-                        key.shape, key_seq_len, kv_padded_len)

         # For CFG, different batch items have different padding patterns (pos vs neg prompts)
         # We need a per-batch mask, but segment_ids need to be 1D for vmapped_splash
@@ -316,7 +309,6 @@ def wrap_flash_attention(query, key, value):
         mask_len = min(key_seq_len, attention_mask.shape[1])
         kv_mask_any = jnp.max(attention_mask[:, :mask_len], axis=0).astype(jnp.int32)  # (mask_len,)

-        jax.debug.print("[DEBUG _tpu_flash_attention] Using OR across batch, sum: {}", kv_mask_any.sum())

         # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
         if key_seq_len > mask_len:
@@ -332,7 +324,6 @@ def wrap_flash_attention(query, key, value):
         # Both are (kv_padded_len,) - element-wise multiplication
         kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)

-        jax.debug.print("[DEBUG _tpu_flash_attention] Final kv_segment_ids sum: {}", kv_segment_ids.sum())

     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)

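For reference, a minimal, self-contained sketch of the per-batch mask reduction these debug prints were instrumenting. The toy shapes and the all-ones stand-in for `kv_segment_ids` are assumptions for illustration; in the real code they come from `_tpu_flash_attention`, and `key_seq_len` and `kv_padded_len` are taken equal here so the padding step fits in one line.

```python
import jax.numpy as jnp

key_seq_len, kv_padded_len = 8, 8

# 1 = real token, 0 = padding; the two CFG batch rows (positive and
# negative prompt) pad differently. Toy values for illustration.
attention_mask = jnp.array([[1, 1, 1, 1, 0, 0],
                            [1, 1, 0, 0, 0, 0]], dtype=jnp.int32)

# OR across the batch: a kv position stays unmasked if ANY batch item
# uses it, because segment_ids must be 1D for the vmapped splash kernel.
mask_len = min(key_seq_len, attention_mask.shape[1])
kv_mask_any = jnp.max(attention_mask[:, :mask_len], axis=0).astype(jnp.int32)

# Positions beyond the original mask are assumed valid (padded with 1s).
kv_mask_padded = jnp.concatenate(
    [kv_mask_any, jnp.ones(kv_padded_len - mask_len, dtype=jnp.int32)])

kv_segment_ids = jnp.ones(kv_padded_len, dtype=jnp.int32)  # stand-in
kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
print(kv_segment_ids)  # [1 1 1 1 0 0 1 1]
```

The batch-wise `jnp.max` acts as a logical OR, which slightly over-includes kv positions (a token padded in one batch item but real in another stays visible to both); that is the trade-off the diff's comments describe for keeping `segment_ids` 1D rather than per-batch.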