Skip to content

Commit d88dd43

Browse files
committed
Trying text_mask 5
1 parent 82260cc commit d88dd43

1 file changed

Lines changed: 14 additions & 7 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -301,30 +301,37 @@ def wrap_flash_attention(query, key, value):
301301
# If attention_mask is provided, apply it to kv_segment_ids
302302
# attention_mask shape: (B, original_kv_seq_len) with 1 for real tokens, 0 for padded
303303
if attention_mask is not None:
304+
# DEBUG: Check shapes and values
304305
jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask shape: {}", attention_mask.shape)
305306
jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask[0] sum: {}", attention_mask[0].sum())
306307
if attention_mask.shape[0] > 1:
307308
jax.debug.print("[DEBUG _tpu_flash_attention] attention_mask[1] sum: {}", attention_mask[1].sum())
308309
jax.debug.print("[DEBUG _tpu_flash_attention] key shape: {}, key_seq_len: {}, kv_padded_len: {}",
309310
key.shape, key_seq_len, kv_padded_len)
310-
# Take the first item since padding pattern is same across batch (especially with CFG)
311-
# This keeps kv_segment_ids as (kv_padded_len,) for compatibility with vmapped_splash
311+
312+
# For CFG, different batch items have different padding patterns (pos vs neg prompts)
313+
# We need a per-batch mask, but segment_ids need to be 1D for vmapped_splash
314+
# Solution: Use logical OR across batch - a position is valid if ANY batch item needs it
315+
# This is safe because zero embeddings at unneeded positions don't affect attention output
312316
mask_len = min(key_seq_len, attention_mask.shape[1])
313-
kv_mask_for_batch = attention_mask[0, :mask_len] # (mask_len,)
314-
jax.debug.print("[DEBUG _tpu_flash_attention] Using attention_mask[0], sum: {}", kv_mask_for_batch.sum())
317+
kv_mask_any = jnp.max(attention_mask[:, :mask_len], axis=0).astype(jnp.int32) # (mask_len,)
318+
319+
jax.debug.print("[DEBUG _tpu_flash_attention] Using OR across batch, sum: {}", kv_mask_any.sum())
320+
315321
# If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
316322
if key_seq_len > mask_len:
317323
extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
318-
kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0) # (key_seq_len,)
324+
kv_mask_any = jnp.concatenate([kv_mask_any, extra_valid], axis=0) # (key_seq_len,)
319325
# Pad to kv_padded_len
320326
if kv_padded_len > key_seq_len:
321327
padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
322-
kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) # (kv_padded_len,)
328+
kv_mask_padded = jnp.concatenate([kv_mask_any, padding], axis=0) # (kv_padded_len,)
323329
else:
324-
kv_mask_padded = kv_mask_for_batch
330+
kv_mask_padded = kv_mask_any
325331
# Combine with existing kv_segment_ids (which handles block alignment padding)
326332
# Both are (kv_padded_len,) - element-wise multiplication
327333
kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
334+
328335
jax.debug.print("[DEBUG _tpu_flash_attention] Final kv_segment_ids sum: {}", kv_segment_ids.sum())
329336

330337
segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)

0 commit comments

Comments
 (0)