
Commit d3aedd0

committed
testing torchax modified kernel.
1 parent ae22683 commit d3aedd0

2 files changed

Lines changed: 741 additions & 56 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 80 additions & 56 deletions
@@ -31,6 +31,8 @@
 from einops import rearrange
 from .. import common_types, max_logging
 
+from . import custom_splash_attention as custom_splash
+
 from . import quantizations
 from .modeling_flax_utils import get_activation
 
@@ -311,7 +313,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
 ) -> jax.Array:
@@ -334,31 +336,42 @@ def _tpu_flash_attention(
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
-    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
-    block_q_sizes = (
-        block_sizes.block_q,
-        block_sizes.block_q_dkv,
-    )
-    block_kv_sizes = (
-        block_sizes.block_kv,
-        block_sizes.block_kv_dkv,
+    bq = 2048
+    bkv = 2048
+    bkv_compute = 1024
+    bkv_compute_in = 256
+    heads_per_tile = 1  # Matches Torchax default
+    # uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+    # block_q_sizes = (
+    #     block_sizes.block_q,
+    #     block_sizes.block_q_dkv,
+    # )
+    # block_kv_sizes = (
+    #     block_sizes.block_kv,
+    #     block_sizes.block_kv_dkv,
+    # )
+    # if uses_fused_kernel:
+    #   block_q_sizes += (block_sizes.block_q_dkv,)
+    #   block_kv_sizes += (block_sizes.block_kv_dkv,)
+    # else:
+    #   block_q_sizes += (block_sizes.block_q_dq,)
+    #   block_kv_sizes += (block_sizes.block_kv_dq,)
+
+    # block_q = max(*block_q_sizes)
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
+
+    # block_kv = max(*block_kv_sizes)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
+    value, _, _ = _pad_data_for_flash(value, heads, bkv)
+
+    bsizes = custom_splash._BlockSizes(
+        block_q=bq,
+        block_kv=bkv,
+        block_kv_compute=bkv_compute,
     )
-    if uses_fused_kernel:
-      block_q_sizes += (block_sizes.block_q_dkv,)
-      block_kv_sizes += (block_sizes.block_kv_dkv,)
-    else:
-      block_q_sizes += (block_sizes.block_q_dq,)
-      block_kv_sizes += (block_sizes.block_kv_dq,)
 
-    block_q = max(*block_q_sizes)
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-
-    block_kv = max(*block_kv_sizes)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
-
-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    # mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    # multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
 
     q_padded_len = query.shape[2]
     q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
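
The hunk above replaces the config-driven block sizes with hard-coded tiles (bq = bkv = 2048, bkv_compute = 1024, bkv_compute_in = 256, heads_per_tile = 1) and pads query/key/value with _pad_data_for_flash before building custom_splash._BlockSizes. A rough sketch of what the padding step amounts to, assuming _pad_data_for_flash pads the sequence axis up to a multiple of the block size and also reports the original length (pad_seq_to_block below is illustrative, not the real helper):

import jax.numpy as jnp

def pad_seq_to_block(x, block):
  # x: (batch, heads, seq, head_dim). Pad the sequence axis up to a multiple of
  # `block` so the splash kernel only ever sees full-sized tiles.
  seq_len = x.shape[2]
  padded_len = ((seq_len + block - 1) // block) * block
  x = jnp.pad(x, ((0, 0), (0, 0), (0, padded_len - seq_len), (0, 0)))
  return x, padded_len, seq_len  # padded tensor, padded length, original length

# A 3000-token query with bq = 2048 is padded to 4096 (two query tiles); a
# 4096-token query is left unchanged.
q = jnp.ones((1, 8, 3000, 64))
q_padded, q_padded_len, q_orig_len = pad_seq_to_block(q, 2048)  # q_padded_len == 4096
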
@@ -369,24 +382,25 @@ def wrap_flash_attention(query, key, value):
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
 
     # If attention_mask is provided, apply it to kv_segment_ids
-    if attention_mask is not None:
-      mask_len = min(key_seq_len, attention_mask.shape[1])
-      kv_mask_for_batch = attention_mask[0, :mask_len]  # (mask_len,)
-      # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
-      if key_seq_len > mask_len:
-        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
-        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)  # (key_seq_len,)
-      # Pad to kv_padded_len
-      if kv_padded_len > key_seq_len:
-        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
-        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)  # (kv_padded_len,)
-      else:
-        kv_mask_padded = kv_mask_for_batch
-      # Both are (kv_padded_len,) - element-wise multiplication
-      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+    # if attention_mask is not None:
+    #   mask_len = min(key_seq_len, attention_mask.shape[1])
+    #   kv_mask_for_batch = attention_mask[0, :mask_len]  # (mask_len,)
+    #   # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
+    #   if key_seq_len > mask_len:
+    #     extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
+    #     kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)  # (key_seq_len,)
+    #   # Pad to kv_padded_len
+    #   if kv_padded_len > key_seq_len:
+    #     padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
+    #     kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)  # (kv_padded_len,)
+    #   else:
+    #     kv_mask_padded = kv_mask_for_batch
+    #   # Both are (kv_padded_len,) - element-wise multiplication
+    #   kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
 
     if attention_kernel == "tokamax_ring":
-      segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      # segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      pass
     else:
       segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
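
This hunk comments out the attention_mask handling and the tokamax SegmentIds path, but keeps the iota-based segment ids that mark real tokens versus block-size padding. A small self-contained sketch of that scheme, with made-up lengths:

import jax
import jax.numpy as jnp

# Hypothetical lengths, chosen only to show the shapes involved.
q_padded_len, query_seq_len = 4096, 3000
kv_padded_len, key_seq_len = 4096, 3000

q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)  # 1 = real token, 0 = padding

kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)

Wrapped in splash_attention_kernel.SegmentIds(q=..., kv=...), a query position only attends to key positions with the same segment id, so the padded tail of the keys is masked out for real tokens. Note that with mask_padding_tokens now defaulting to False, a later hunk sets segment_ids to None, and the flash/tokamax paths no longer pass it to the kernel at all.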

@@ -403,18 +417,26 @@ def wrap_flash_attention(query, key, value):
           save_residuals=False,
       )
     elif attention_kernel == "tokamax_ring":
-      mask = tokamax_splash_attention_mask.FullMask(
-          _shape=(query.shape[2], key.shape[2]),
-      )
-      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-          mask=mask,
-          is_mqa=False,
-          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          save_residuals=False,
-          ring_axis="context",
-          rotate_segment_ids=False,  # We don't rotate segment ids in tokamax ring attention because our segment ids are for padding; each kv shard has the same segment ids
+      # mask = tokamax_splash_attention_mask.FullMask(
+      #     _shape=(query.shape[2], key.shape[2]),
+      # )
+      # splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
+      #     mask=mask,
+      #     is_mqa=False,
+      #     config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+      #     save_residuals=False,
+      #     ring_axis="context",
+      #     rotate_segment_ids=False,  # We don't rotate segment ids in tokamax ring attention because our segment ids are for padding; each kv shard has the same segment ids
+      # )
+      splash_kernel = custom_splash.make_splash_mha(
+          block_sizes=bsizes,
+          bkv_compute_in=bkv_compute_in,
+          orig_q_seq_len=query_seq_len,
+          orig_kv_seq_len=key_seq_len,
+          heads_per_tile=heads_per_tile
       )
     else:
+      splash_kernel = custom_splash
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
           head_shards=1,  # the size of the axis that is sharding over heads
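
The replacement kernel is built from the hard-coded tile sizes and is handed orig_q_seq_len / orig_kv_seq_len (presumably the pre-padding lengths, so the kernel can account for the padded tail internally). A minimal sanity check on how those tile sizes nest; the divisibility requirement is an assumption mirroring the usual splash-attention BlockSizes constraints, not something this diff states:

bq, bkv, bkv_compute, bkv_compute_in = 2048, 2048, 1024, 256
# Each kv block should split into whole compute chunks, and each compute chunk
# into whole inner chunks (assumed constraint).
assert bkv % bkv_compute == 0
assert bkv_compute % bkv_compute_in == 0
# With 2048-token tiles, a padded 4096-token sequence runs 2 query tiles by
# 2 kv tiles per head.
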
@@ -424,12 +446,14 @@ def wrap_flash_attention(query, key, value):
           residual_checkpoint_name=residual_checkpoint_name,
       )
 
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0), out_axes=0)
 
     if not mask_padding_tokens:
       segment_ids = None
     if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
-      attention_output = vmapped_splash(query, key, value, segment_ids)
+      attention_output = vmapped_splash(query, key, value)
+      if attention_kernel == "tokamax_ring":
+        attention_output = jnp.swapaxes(attention_output, 2, 3)
     else:
       if num_context_shards > 1:
         out, (lse,) = vmapped_splash(query, key, value, segment_ids)
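
The vmap change drops the broadcast segment_ids argument (in_axes=(0, 0, 0)), so the per-batch kernel is mapped over the leading batch axis only, and the extra jnp.swapaxes for tokamax_ring suggests the custom kernel returns its last two axes in the opposite order from the other paths. A toy stand-in showing the mechanics (toy_kernel is illustrative, not the custom splash kernel):

import jax
import jax.numpy as jnp

def toy_kernel(q, k, v):
  # Stand-in for the per-batch kernel: inputs are (heads, seq, head_dim).
  scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
  out = jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(scores, axis=-1), v)
  # Pretend the kernel emits (heads, head_dim, seq), as the swapaxes above implies.
  return jnp.swapaxes(out, 1, 2)

q = k = v = jnp.ones((2, 4, 128, 64))  # (batch, heads, seq, head_dim)
vmapped = jax.vmap(toy_kernel, in_axes=(0, 0, 0), out_axes=0)
out = vmapped(q, k, v)                 # (batch, heads, head_dim, seq)
out = jnp.swapaxes(out, 2, 3)          # back to (batch, heads, seq, head_dim)
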
@@ -504,7 +528,7 @@ def _ulysses_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
 ) -> jax.Array:
@@ -738,7 +762,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
     residual_checkpoint_name: str | None = None,
     attention_mask: Array = None,
 ):
@@ -981,7 +1005,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
-      mask_padding_tokens: bool = True,
+      mask_padding_tokens: bool = False,
       residual_checkpoint_name: str | None = None,
   ):
     self.dpa_layer = None
@@ -1139,7 +1163,7 @@ def __init__(
       qkv_bias: bool = False,
       quant: Quant = None,
       is_self_attention: bool = True,
-      mask_padding_tokens: bool = True,
+      mask_padding_tokens: bool = False,
       residual_checkpoint_name: str | None = None,
       enable_jax_named_scopes: bool = False,
       added_kv_proj_dim: Optional[int] = None,  # New for I2V
