Skip to content

Commit 1fbfd5b

Browse files
committed
Trying text_mask 7
1 parent e7bd680 commit 1fbfd5b

1 file changed

Lines changed: 0 additions & 9 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -301,13 +301,6 @@ def wrap_flash_attention(query, key, value):
301301
# If attention_mask is provided, apply it to kv_segment_ids
302302
# attention_mask shape: (B, original_kv_seq_len) with 1 for real tokens, 0 for padded
303303
if attention_mask is not None:
304-
# DEBUG: Check shapes and values
305-
jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask shape: {}", attention_mask.shape)
306-
jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask[0] sum: {}", attention_mask[0].sum())
307-
if attention_mask.shape[0] > 1:
308-
jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask[1] sum: {}", attention_mask[1].sum())
309-
jax.debug.print("[DEBUG _tpu_flash_attention] key shape: {}, key_seq_len: {}, kv_padded_len: {}",
310-
key.shape, key_seq_len, kv_padded_len)
311304

312305
# For CFG, different batch items have different padding patterns (pos vs neg prompts)
313306
# We need a per-batch mask, but segment_ids need to be 1D for vmapped_splash
@@ -316,7 +309,6 @@ def wrap_flash_attention(query, key, value):
316309
mask_len = min(key_seq_len, attention_mask.shape[1])
317310
kv_mask_any = jnp.max(attention_mask[:, :mask_len], axis=0).astype(jnp.int32) # (mask_len,)
318311

319-
jax.debug.print("[DEBUG _tpu_flash_attention] Using OR across batch, sum: {}", kv_mask_any.sum())
320312

321313
# If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
322314
if key_seq_len > mask_len:
@@ -332,7 +324,6 @@ def wrap_flash_attention(query, key, value):
332324
# Both are (kv_padded_len,) - element-wise multiplication
333325
kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
334326

335-
jax.debug.print("[DEBUG _tpu_flash_attention] Final kv_segment_ids sum: {}", kv_segment_ids.sum())
336327

337328
segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
338329

0 commit comments

Comments
 (0)