 from einops import rearrange
 from .. import common_types, max_logging

+from . import custom_splash_attention as custom_splash
+
 from . import quantizations
 from .modeling_flax_utils import get_activation

@@ -311,7 +313,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
 ) -> jax.Array:
@@ -334,31 +336,42 @@ def _tpu_flash_attention(
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
-    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
-    block_q_sizes = (
-        block_sizes.block_q,
-        block_sizes.block_q_dkv,
-    )
-    block_kv_sizes = (
-        block_sizes.block_kv,
-        block_sizes.block_kv_dkv,
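+    # Hard-coded tile sizes for the custom splash-attention kernel, replacing the
+    # values previously derived from the BlockSizes config.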
+    bq = 2048
+    bkv = 2048
+    bkv_compute = 1024
+    bkv_compute_in = 256
+    heads_per_tile = 1  # Matches Torchax default
+    # uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+    # block_q_sizes = (
+    #     block_sizes.block_q,
+    #     block_sizes.block_q_dkv,
+    # )
+    # block_kv_sizes = (
+    #     block_sizes.block_kv,
+    #     block_sizes.block_kv_dkv,
+    # )
+    # if uses_fused_kernel:
+    #   block_q_sizes += (block_sizes.block_q_dkv,)
+    #   block_kv_sizes += (block_sizes.block_kv_dkv,)
+    # else:
+    #   block_q_sizes += (block_sizes.block_q_dq,)
+    #   block_kv_sizes += (block_sizes.block_kv_dq,)
+
+    # block_q = max(*block_q_sizes)
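+    # _pad_data_for_flash pads the sequence dimension up to a multiple of the tile
+    # size; the original (unpadded) lengths it returns are reused below for the
+    # segment ids and for the kernel's orig_*_seq_len arguments.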
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
+
+    # block_kv = max(*block_kv_sizes)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
+    value, _, _ = _pad_data_for_flash(value, heads, bkv)
+
+    bsizes = custom_splash._BlockSizes(
+        block_q=bq,
+        block_kv=bkv,
+        block_kv_compute=bkv_compute,
     )
-    if uses_fused_kernel:
-      block_q_sizes += (block_sizes.block_q_dkv,)
-      block_kv_sizes += (block_sizes.block_kv_dkv,)
-    else:
-      block_q_sizes += (block_sizes.block_q_dq,)
-      block_kv_sizes += (block_sizes.block_kv_dq,)

-    block_q = max(*block_q_sizes)
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-
-    block_kv = max(*block_kv_sizes)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
-
-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    # mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    # multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])

     q_padded_len = query.shape[2]
     q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
@@ -369,24 +382,25 @@ def wrap_flash_attention(query, key, value):
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)

     # If attention_mask is provided, apply it to kv_segment_ids
-    if attention_mask is not None:
-      mask_len = min(key_seq_len, attention_mask.shape[1])
-      kv_mask_for_batch = attention_mask[0, :mask_len]  # (mask_len,)
-      # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
-      if key_seq_len > mask_len:
-        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
-        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)  # (key_seq_len,)
-      # Pad to kv_padded_len
-      if kv_padded_len > key_seq_len:
-        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
-        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)  # (kv_padded_len,)
-      else:
-        kv_mask_padded = kv_mask_for_batch
-      # Both are (kv_padded_len,) - element-wise multiplication
-      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+    # if attention_mask is not None:
+    #   mask_len = min(key_seq_len, attention_mask.shape[1])
+    #   kv_mask_for_batch = attention_mask[0, :mask_len]  # (mask_len,)
+    #   # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
+    #   if key_seq_len > mask_len:
+    #     extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
+    #     kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)  # (key_seq_len,)
+    #   # Pad to kv_padded_len
+    #   if kv_padded_len > key_seq_len:
+    #     padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
+    #     kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)  # (kv_padded_len,)
+    #   else:
+    #     kv_mask_padded = kv_mask_for_batch
+    #   # Both are (kv_padded_len,) - element-wise multiplication
+    #   kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)

     if attention_kernel == "tokamax_ring":
-      segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      # segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      pass
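+      # segment_ids is not built on this path; the vmapped call below no longer
+      # passes a segment-ids argument.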
     else:
       segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)

@@ -403,18 +417,26 @@ def wrap_flash_attention(query, key, value):
           save_residuals=False,
       )
     elif attention_kernel == "tokamax_ring":
-      mask = tokamax_splash_attention_mask.FullMask(
-          _shape=(query.shape[2], key.shape[2]),
-      )
-      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-          mask=mask,
-          is_mqa=False,
-          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          save_residuals=False,
-          ring_axis="context",
-          rotate_segment_ids=False,  # We don't rotate segment ids in tokamax ring attention because our segment ids are only for padding; each kv shard has the same segment ids
+      # mask = tokamax_splash_attention_mask.FullMask(
+      #     _shape=(query.shape[2], key.shape[2]),
+      # )
+      # splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
+      #     mask=mask,
+      #     is_mqa=False,
+      #     config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+      #     save_residuals=False,
+      #     ring_axis="context",
+      #     rotate_segment_ids=False,  # We don't rotate segment ids in tokamax ring attention because our segment ids are only for padding; each kv shard has the same segment ids
+      # )
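+      # Build the custom splash-attention kernel in place of the tokamax ring kernel.
+      # Passing the original (unpadded) sequence lengths presumably lets the kernel
+      # ignore padded positions (inferred from the orig_*_seq_len argument names).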
+      splash_kernel = custom_splash.make_splash_mha(
+          block_sizes=bsizes,
+          bkv_compute_in=bkv_compute_in,
+          orig_q_seq_len=query_seq_len,
+          orig_kv_seq_len=key_seq_len,
+          heads_per_tile=heads_per_tile
       )
     else:
+      splash_kernel = custom_splash
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
           head_shards=1,  # the size of the axis that is sharded over heads
@@ -424,12 +446,14 @@ def wrap_flash_attention(query, key, value):
           residual_checkpoint_name=residual_checkpoint_name,
       )

-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0), out_axes=0)

     if not mask_padding_tokens:
       segment_ids = None
     if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
-      attention_output = vmapped_splash(query, key, value, segment_ids)
+      attention_output = vmapped_splash(query, key, value)
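+      # The custom kernel appears to return its output with the last two axes
+      # transposed relative to (batch, heads, seq, head_dim); swap them back.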
+      if attention_kernel == "tokamax_ring":
+        attention_output = jnp.swapaxes(attention_output, 2, 3)
     else:
       if num_context_shards > 1:
         out, (lse,) = vmapped_splash(query, key, value, segment_ids)
@@ -504,7 +528,7 @@ def _ulysses_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
 ) -> jax.Array:
@@ -738,7 +762,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
     residual_checkpoint_name: str | None = None,
     attention_mask: Array = None,
 ):
@@ -981,7 +1005,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
-      mask_padding_tokens: bool = True,
+      mask_padding_tokens: bool = False,
       residual_checkpoint_name: str | None = None,
   ):
     self.dpa_layer = None
@@ -1139,7 +1163,7 @@ def __init__(
       qkv_bias: bool = False,
       quant: Quant = None,
       is_self_attention: bool = True,
-      mask_padding_tokens: bool = True,
+      mask_padding_tokens: bool = False,
       residual_checkpoint_name: str | None = None,
       enable_jax_named_scopes: bool = False,
       added_kv_proj_dim: Optional[int] = None,  # New for I2V