
Commit e9d0620

entrpneltsai authored and committed
testing torchax modified kernel.
1 parent c98002f commit e9d0620

2 files changed: 730 additions & 60 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 69 additions & 60 deletions
@@ -31,6 +31,8 @@
 from einops import rearrange
 from .. import common_types, max_logging
 
+from . import custom_splash_attention as custom_splash
+
 from . import quantizations
 from .modeling_flax_utils import get_activation
 
@@ -313,7 +315,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
     use_base2_exp: bool = False,
@@ -338,31 +340,42 @@ def _tpu_flash_attention(
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
-    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
-    block_q_sizes = (
-        block_sizes.block_q,
-        block_sizes.block_q_dkv,
-    )
-    block_kv_sizes = (
-        block_sizes.block_kv,
-        block_sizes.block_kv_dkv,
+    bq = 2048
+    bkv = 2048
+    bkv_compute = 1024
+    bkv_compute_in = 256
+    heads_per_tile = 1  # Matches Torchax default
+    # uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+    # block_q_sizes = (
+    #     block_sizes.block_q,
+    #     block_sizes.block_q_dkv,
+    # )
+    # block_kv_sizes = (
+    #     block_sizes.block_kv,
+    #     block_sizes.block_kv_dkv,
+    # )
+    # if uses_fused_kernel:
+    #   block_q_sizes += (block_sizes.block_q_dkv,)
+    #   block_kv_sizes += (block_sizes.block_kv_dkv,)
+    # else:
+    #   block_q_sizes += (block_sizes.block_q_dq,)
+    #   block_kv_sizes += (block_sizes.block_kv_dq,)
+
+    # block_q = max(*block_q_sizes)
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
+
+    # block_kv = max(*block_kv_sizes)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
+    value, _, _ = _pad_data_for_flash(value, heads, bkv)
+
+    bsizes = custom_splash._BlockSizes(
+        block_q=bq,
+        block_kv=bkv,
+        block_kv_compute=bkv_compute,
     )
-    if uses_fused_kernel:
-      block_q_sizes += (block_sizes.block_q_dkv,)
-      block_kv_sizes += (block_sizes.block_kv_dkv,)
-    else:
-      block_q_sizes += (block_sizes.block_q_dq,)
-      block_kv_sizes += (block_sizes.block_kv_dq,)
 
-    block_q = max(*block_q_sizes)
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-
-    block_kv = max(*block_kv_sizes)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
-
-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    # mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    # multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
 
     q_padded_len = query.shape[2]
     q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
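Note: _pad_data_for_flash itself is not part of this diff, so only its call sites above are visible (it appears to return the padded tensor plus size information). As a rough, standalone sketch of the kind of block padding those call sites suggest, here is a hypothetical pad_seq_to_block helper; its name, return order, and the (batch, heads, seq, head_dim) layout are assumptions, not the repo's actual implementation:

import jax.numpy as jnp

def pad_seq_to_block(x: jnp.ndarray, block: int):
  # x: (batch, heads, seq, head_dim); pad the seq axis up to the next multiple of `block`.
  seq_len = x.shape[2]
  padded_len = -(-seq_len // block) * block  # ceiling division
  pad = padded_len - seq_len
  x = jnp.pad(x, ((0, 0), (0, 0), (0, pad), (0, 0)))
  return x, padded_len, seq_len

q = jnp.ones((1, 4, 3000, 64))
q_padded, padded_len, orig_len = pad_seq_to_block(q, 2048)
print(q_padded.shape, padded_len, orig_len)  # (1, 4, 4096, 64) 4096 3000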
@@ -373,24 +386,25 @@ def wrap_flash_attention(query, key, value):
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
 
     # If attention_mask is provided, apply it to kv_segment_ids
-    if attention_mask is not None:
-      mask_len = min(key_seq_len, attention_mask.shape[1])
-      kv_mask_for_batch = attention_mask[0, :mask_len]  # (mask_len,)
-      # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
-      if key_seq_len > mask_len:
-        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
-        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)  # (key_seq_len,)
-      # Pad to kv_padded_len
-      if kv_padded_len > key_seq_len:
-        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
-        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)  # (kv_padded_len,)
-      else:
-        kv_mask_padded = kv_mask_for_batch
-      # Both are (kv_padded_len,) - element-wise multiplication
-      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+    # if attention_mask is not None:
+    #   mask_len = min(key_seq_len, attention_mask.shape[1])
+    #   kv_mask_for_batch = attention_mask[0, :mask_len]  # (mask_len,)
+    #   # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
+    #   if key_seq_len > mask_len:
+    #     extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
+    #     kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)  # (key_seq_len,)
+    #   # Pad to kv_padded_len
+    #   if kv_padded_len > key_seq_len:
+    #     padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
+    #     kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)  # (kv_padded_len,)
+    #   else:
+    #     kv_mask_padded = kv_mask_for_batch
+    #   # Both are (kv_padded_len,) - element-wise multiplication
+    #   kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
 
     if attention_kernel == "tokamax_ring":
-      segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      # segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      pass
     else:
       segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
 
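The segment ids built above use jax.lax.broadcasted_iota so that positions below the original sequence length are marked 1 (real tokens) and positions introduced by block padding are marked 0. A minimal runnable sketch of that masking pattern; the helper name and the 2048-wide block are illustrative only:

import jax
import jax.numpy as jnp

def padding_segment_ids(orig_len: int, padded_len: int) -> jnp.ndarray:
  # 1 for real tokens, 0 for positions added when padding to a block multiple.
  indices = jax.lax.broadcasted_iota(jnp.int32, (padded_len,), 0)
  return (indices < orig_len).astype(jnp.int32)

# e.g. a 1000-token sequence padded up to one 2048-wide block
print(padding_segment_ids(1000, 2048))  # first 1000 entries are 1, the rest 0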
@@ -412,23 +426,16 @@ def wrap_flash_attention(query, key, value):
           save_residuals=False,
       )
     elif attention_kernel == "tokamax_ring":
-      mask = tokamax_splash_attention_mask.FullMask(
-          _shape=(query.shape[2], key.shape[2]),
+      splash_kernel = custom_splash.make_splash_mha(
+          block_sizes=bsizes,
+          bkv_compute_in=bkv_compute_in,
+          orig_q_seq_len=query_seq_len,
+          orig_kv_seq_len=key_seq_len,
+          heads_per_tile=heads_per_tile,
       )
-      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-          mask=mask,
-          is_mqa=False,
-          config=convert_to_tokamax_splash_config(
-              block_sizes,
-              residual_checkpoint_name=residual_checkpoint_name,
-              use_base2_exp=use_base2_exp,
-              use_experimental_scheduler=use_experimental_scheduler,
-          ),
-          save_residuals=False,
-          ring_axis="context",
-          rotate_segment_ids=False,  # We don't rotate segment ids in tokamax ring attention because our segment ids is for padding each kv shard has same segment ids
       )
     else:
+      splash_kernel = custom_splash
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
           head_shards=1,  # the sizes of the axis is sharding over heads
@@ -438,12 +445,14 @@ def wrap_flash_attention(query, key, value):
           residual_checkpoint_name=residual_checkpoint_name,
       )
 
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0), out_axes=0)
 
     if not mask_padding_tokens:
       segment_ids = None
     if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
-      attention_output = vmapped_splash(query, key, value, segment_ids)
+      attention_output = vmapped_splash(query, key, value)
+      if attention_kernel == "tokamax_ring":
+        attention_output = jnp.swapaxes(attention_output, 2, 3)
     else:
       if num_context_shards > 1:
         out, (lse,) = vmapped_splash(query, key, value, segment_ids)
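The change to jax.vmap(splash_kernel, in_axes=(0, 0, 0), out_axes=0) maps a per-example kernel over the leading batch axis of query, key, and value without passing segment ids. A toy sketch of the same vmap pattern, using plain softmax attention as a stand-in for the Pallas splash kernel (toy_attention and the shapes below are illustrative assumptions):

import jax
import jax.numpy as jnp

# Toy stand-in for the splash kernel: softmax attention on one
# (heads, seq, head_dim) example. The real kernel is a custom TPU kernel.
def toy_attention(q, k, v):
  scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(scores, axis=-1), v)

batched_attention = jax.vmap(toy_attention, in_axes=(0, 0, 0), out_axes=0)

q = k = v = jnp.ones((2, 4, 128, 64))  # (batch, heads, seq, head_dim)
print(batched_attention(q, k, v).shape)  # (2, 4, 128, 64)

The added jnp.swapaxes(attention_output, 2, 3) on the tokamax_ring path then transposes the last two axes of the 4-D output, suggesting the custom kernel returns them in the opposite order from what the surrounding code expects.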
@@ -518,7 +527,7 @@ def _ulysses_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
 ) -> jax.Array:
@@ -752,7 +761,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
    residual_checkpoint_name: str | None = None,
     attention_mask: Array = None,
     use_base2_exp: bool = False,
@@ -999,7 +1008,7 @@ def __init__(
      flash_block_sizes: BlockSizes = None,
      dtype: DType = jnp.float32,
      quant: Quant = None,
-      mask_padding_tokens: bool = True,
+      mask_padding_tokens: bool = False,
      residual_checkpoint_name: str | None = None,
      use_base2_exp: bool = False,
      use_experimental_scheduler: bool = False,
@@ -1167,7 +1176,7 @@ def __init__(
      qkv_bias: bool = False,
      quant: Quant = None,
      is_self_attention: bool = True,
-      mask_padding_tokens: bool = True,
+      mask_padding_tokens: bool = False,
      residual_checkpoint_name: str | None = None,
      enable_jax_named_scopes: bool = False,
      added_kv_proj_dim: Optional[int] = None,  # New for I2V

0 commit comments