 from einops import rearrange
 from .. import common_types, max_logging

+from . import custom_splash_attention as custom_splash
+
 from . import quantizations
 from .modeling_flax_utils import get_activation

@@ -313,7 +315,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
     use_base2_exp: bool = False,
@@ -338,31 +340,42 @@ def _tpu_flash_attention(
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
-    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
-    block_q_sizes = (
-        block_sizes.block_q,
-        block_sizes.block_q_dkv,
-    )
-    block_kv_sizes = (
-        block_sizes.block_kv,
-        block_sizes.block_kv_dkv,
+    bq = 2048
+    bkv = 2048
+    bkv_compute = 1024
+    bkv_compute_in = 256
+    heads_per_tile = 1  # Matches Torchax default
+    # uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+    # block_q_sizes = (
+    #     block_sizes.block_q,
+    #     block_sizes.block_q_dkv,
+    # )
+    # block_kv_sizes = (
+    #     block_sizes.block_kv,
+    #     block_sizes.block_kv_dkv,
+    # )
+    # if uses_fused_kernel:
+    #   block_q_sizes += (block_sizes.block_q_dkv,)
+    #   block_kv_sizes += (block_sizes.block_kv_dkv,)
+    # else:
+    #   block_q_sizes += (block_sizes.block_q_dq,)
+    #   block_kv_sizes += (block_sizes.block_kv_dq,)
+
+    # block_q = max(*block_q_sizes)
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
+
+    # block_kv = max(*block_kv_sizes)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
+    value, _, _ = _pad_data_for_flash(value, heads, bkv)
+
+    bsizes = custom_splash._BlockSizes(
+        block_q=bq,
+        block_kv=bkv,
+        block_kv_compute=bkv_compute,
     )
-    if uses_fused_kernel:
-      block_q_sizes += (block_sizes.block_q_dkv,)
-      block_kv_sizes += (block_sizes.block_kv_dkv,)
-    else:
-      block_q_sizes += (block_sizes.block_q_dq,)
-      block_kv_sizes += (block_sizes.block_kv_dq,)

-    block_q = max(*block_q_sizes)
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-
-    block_kv = max(*block_kv_sizes)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
-
-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    # mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    # multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])

     q_padded_len = query.shape[2]
     q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
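
For context, a minimal sketch (not part of this commit) of what the padding and segment-id construction in the hunk above amounts to, assuming `_pad_data_for_flash` zero-pads the sequence axis up to the next multiple of the block size and reports the original length; the helper name `pad_to_block_multiple` is hypothetical:

import jax
import jax.numpy as jnp

def pad_to_block_multiple(x, block):
  # Hypothetical stand-in for _pad_data_for_flash: zero-pad the sequence axis
  # (axis 2 of a (batch, heads, seq, head_dim) array) to a multiple of `block`
  # and return the original sequence length alongside the padded array.
  seq_len = x.shape[2]
  padded_len = -(-seq_len // block) * block  # ceiling division
  padded = jnp.pad(x, ((0, 0), (0, 0), (0, padded_len - seq_len), (0, 0)))
  return padded, seq_len

q = jnp.ones((1, 8, 3000, 128))
q_padded, q_seq_len = pad_to_block_multiple(q, 2048)        # (1, 8, 4096, 128), 3000

# Segment ids mark which padded positions hold real tokens, as in the hunk above.
q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded.shape[2],), 0)
q_segment_ids = (q_indices < q_seq_len).astype(jnp.int32)   # 1 for real tokens, 0 for padding
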
@@ -373,24 +386,25 @@ def wrap_flash_attention(query, key, value):
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)

     # If attention_mask is provided, apply it to kv_segment_ids
-    if attention_mask is not None:
-      mask_len = min(key_seq_len, attention_mask.shape[1])
-      kv_mask_for_batch = attention_mask[0, :mask_len]  # (mask_len,)
-      # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
-      if key_seq_len > mask_len:
-        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
-        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)  # (key_seq_len,)
-      # Pad to kv_padded_len
-      if kv_padded_len > key_seq_len:
-        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
-        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)  # (kv_padded_len,)
-      else:
-        kv_mask_padded = kv_mask_for_batch
-      # Both are (kv_padded_len,) - element-wise multiplication
-      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+    # if attention_mask is not None:
+    #   mask_len = min(key_seq_len, attention_mask.shape[1])
+    #   kv_mask_for_batch = attention_mask[0, :mask_len]  # (mask_len,)
+    #   # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
+    #   if key_seq_len > mask_len:
+    #     extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
+    #     kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)  # (key_seq_len,)
+    #   # Pad to kv_padded_len
+    #   if kv_padded_len > key_seq_len:
+    #     padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
+    #     kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)  # (kv_padded_len,)
+    #   else:
+    #     kv_mask_padded = kv_mask_for_batch
+    #   # Both are (kv_padded_len,) - element-wise multiplication
+    #   kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)

     if attention_kernel == "tokamax_ring":
-      segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      # segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      pass
     else:
       segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)

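
The block commented out above folded the caller-provided `attention_mask` into `kv_segment_ids`. A small worked example of that element-wise merge (illustrative values only, not from the commit):

import jax.numpy as jnp

key_seq_len, kv_padded_len = 5, 8
kv_indices = jnp.arange(kv_padded_len)
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)         # [1 1 1 1 1 0 0 0]

# Caller-provided mask marking the fourth key token as invalid.
attention_mask = jnp.array([[1, 1, 1, 0, 1]], dtype=jnp.int32)        # (batch, key_seq_len)
kv_mask_padded = jnp.pad(attention_mask[0], (0, kv_padded_len - key_seq_len))

# A key position stays valid only if it is both non-padding and unmasked.
kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)  # [1 1 1 0 1 0 0 0]
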
@@ -412,23 +426,16 @@ def wrap_flash_attention(query, key, value):
           save_residuals=False,
       )
     elif attention_kernel == "tokamax_ring":
-      mask = tokamax_splash_attention_mask.FullMask(
-          _shape=(query.shape[2], key.shape[2]),
+      splash_kernel = custom_splash.make_splash_mha(
+          block_sizes=bsizes,
+          bkv_compute_in=bkv_compute_in,
+          orig_q_seq_len=query_seq_len,
+          orig_kv_seq_len=key_seq_len,
+          heads_per_tile=heads_per_tile,
       )
-      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-          mask=mask,
-          is_mqa=False,
-          config=convert_to_tokamax_splash_config(
-              block_sizes,
-              residual_checkpoint_name=residual_checkpoint_name,
-              use_base2_exp=use_base2_exp,
-              use_experimental_scheduler=use_experimental_scheduler,
-          ),
-          save_residuals=False,
-          ring_axis="context",
-          rotate_segment_ids=False,  # We don't rotate segment ids in tokamax ring attention because our segment ids is for padding each kv shard has same segment ids
       )
     else:
+      splash_kernel = custom_splash
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
           head_shards=1,  # the sizes of the axis is sharding over heads
@@ -438,12 +445,14 @@ def wrap_flash_attention(query, key, value):
           residual_checkpoint_name=residual_checkpoint_name,
       )

-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0), out_axes=0)

     if not mask_padding_tokens:
       segment_ids = None
     if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
-      attention_output = vmapped_splash(query, key, value, segment_ids)
+      attention_output = vmapped_splash(query, key, value)
+      if attention_kernel == "tokamax_ring":
+        attention_output = jnp.swapaxes(attention_output, 2, 3)
     else:
       if num_context_shards > 1:
         out, (lse,) = vmapped_splash(query, key, value, segment_ids)
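
`jax.vmap` maps the per-example kernel over the leading batch axis; with the explicit `segment_ids` argument dropped, `in_axes=(0, 0, 0)` now matches the three-array call. A toy sketch of the vmap and of the axis swap applied on the `tokamax_ring` path (`toy_kernel` is a stand-in, not the Pallas kernel, and the transposed output layout motivating the swap is an assumption):

import jax
import jax.numpy as jnp

def toy_kernel(q, k, v):
  # Stand-in for splash_kernel: per-example inputs of shape (heads, seq, head_dim).
  scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(jnp.float32(q.shape[-1]))
  return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(scores, axis=-1), v)

q = k = v = jnp.ones((2, 4, 256, 64))                        # (batch, heads, seq, head_dim)
vmapped = jax.vmap(toy_kernel, in_axes=(0, 0, 0), out_axes=0)
out = vmapped(q, k, v)                                       # (2, 4, 256, 64)

# If a kernel instead returned (batch, heads, head_dim, seq), swapping axes 2 and 3,
# as the tokamax_ring branch does above, restores (batch, heads, seq, head_dim).
alt_layout = jnp.transpose(out, (0, 1, 3, 2))                # (2, 4, 64, 256)
restored = jnp.swapaxes(alt_layout, 2, 3)                    # (2, 4, 256, 64)
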
@@ -518,7 +527,7 @@ def _ulysses_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
 ) -> jax.Array:
@@ -752,7 +761,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
     residual_checkpoint_name: str | None = None,
     attention_mask: Array = None,
     use_base2_exp: bool = False,
@@ -999,7 +1008,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
-      mask_padding_tokens: bool = True,
+      mask_padding_tokens: bool = False,
       residual_checkpoint_name: str | None = None,
       use_base2_exp: bool = False,
       use_experimental_scheduler: bool = False,
@@ -1167,7 +1176,7 @@ def __init__(
       qkv_bias: bool = False,
       quant: Quant = None,
       is_self_attention: bool = True,
-      mask_padding_tokens: bool = True,
+      mask_padding_tokens: bool = False,
       residual_checkpoint_name: str | None = None,
       enable_jax_named_scopes: bool = False,
       added_kv_proj_dim: Optional[int] = None,  # New for I2V