Skip to content

Commit 82260cc

Browse files
committed
Trying text_mask 4
1 parent 4907d08 commit 82260cc

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,10 +301,17 @@ def wrap_flash_attention(query, key, value):
301301
# If attention_mask is provided, apply it to kv_segment_ids
302302
# attention_mask shape: (B, original_kv_seq_len) with 1 for real tokens, 0 for padded
303303
if attention_mask is not None:
304+
jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask shape: {}", attention_mask.shape)
305+
jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask[0] sum: {}", attention_mask[0].sum())
306+
if attention_mask.shape[0] > 1:
307+
jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask[1] sum: {}", attention_mask[1].sum())
308+
jax.debug.print("[DEBUG _tpu_flash_attention] key shape: {}, key_seq_len: {}, kv_padded_len: {}",
309+
key.shape, key_seq_len, kv_padded_len)
304310
# Take the first item since padding pattern is same across batch (especially with CFG)
305311
# This keeps kv_segment_ids as (kv_padded_len,) for compatibility with vmapped_splash
306312
mask_len = min(key_seq_len, attention_mask.shape[1])
307313
kv_mask_for_batch = attention_mask[0, :mask_len] # (mask_len,)
314+
jax.debug.print("[DEBUG _tpu_flash_attention] Using attention_mask[0], sum: {}", kv_mask_for_batch.sum())
308315
# If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
309316
if key_seq_len > mask_len:
310317
extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
@@ -318,6 +325,7 @@ def wrap_flash_attention(query, key, value):
318325
# Combine with existing kv_segment_ids (which handles block alignment padding)
319326
# Both are (kv_padded_len,) - element-wise multiplication
320327
kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
328+
jax.debug.print("[DEBUG _tpu_flash_attention] Final kv_segment_ids sum: {}", kv_segment_ids.sum())
321329

322330
segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
323331

0 commit comments

Comments (0)