 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
 from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
-from tokamax._src.ops.experimental.tpu.splash_attention import base as tokamax_base
 from einops import rearrange
 from .. import common_types, max_logging

@@ -280,49 +279,17 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
-
-  # Pre-padding and Ring Kernel creation outside shard_map
-  if attention_kernel == "tokamax_ring":
-    block_q = max(block_sizes.block_q, block_sizes.block_q_dkv)
-    block_kv = max(block_sizes.block_kv, block_sizes.block_kv_dkv)
-
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q, num_shards=num_fsdp_shards)
-    key, _, _ = _pad_data_for_flash(key, heads, block_kv, num_shards=num_fsdp_shards)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv, num_shards=num_fsdp_shards)
-
-    mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    ring_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-        mask=mask,
-        is_mqa=False,
-        config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-        save_residuals=True,
-        ring_axis="fsdp",
-        q_seq_shards=num_fsdp_shards,
-        kv_seq_shards=num_fsdp_shards,
-    )
-    kernel_spec = ring_kernel.manual_sharding_spec()
-  else:
-    # Logic for other kernels remains unchanged regarding local padding
-    ring_kernel = None
-    kernel_spec = None
-
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(q_axis_names, kv_axis_names, kv_axis_names, kernel_spec),
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value, ring_kernel):
-
-    if attention_kernel == "tokamax_ring":
-      # For bidirectional attention, segment_ids can be None to hit the performance shortcut
-      segment_ids = None
-      vmapped_splash = jax.vmap(ring_kernel, in_axes=(0, 0, 0, None))
-      return vmapped_splash(query, key, value, segment_ids)
+  def wrap_flash_attention(query, key, value):

     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (
@@ -357,7 +324,6 @@ def wrap_flash_attention(query, key, value, ring_kernel):
     kv_padded_len = key.shape[2]
     kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
-
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)

     # make_splash_mha is wrapped around shardmap and seq and head is already
@@ -370,6 +336,15 @@ def wrap_flash_attention(query, key, value, ring_kernel):
           config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
           save_residuals=True if "ring" in attention_kernel else False,
       )
+    elif attention_kernel == "tokamax_ring":
+      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
+          mask=mask,
+          is_mqa=False,
+          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+          save_residuals=True,
+          ring_axis="fsdp",
+      )
     else:
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
@@ -385,7 +360,7 @@ def wrap_flash_attention(query, key, value, ring_kernel):

     if not mask_padding_tokens:
       segment_ids = None
-    if attention_kernel in ["flash", "tokamax_flash"]:
+    if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
       attention_output = vmapped_splash(query, key, value, segment_ids)
     else:
       if num_fsdp_shards > 1:
@@ -437,11 +412,7 @@ def ring_scan_body(carry, _):
437412 "Warning, batch dimension should be shardable among the devices in data and fsdp"
438413 f" axis, batch dimension: { query .shape [0 ]} , devices_in_data_fsdp: { devices_in_data_fsdp } "
439414 )
440- x = wrap_flash_attention (query , key , value , ring_kernel )
441-
442- if attention_kernel == "tokamax_ring" :
443- x = x [:, :, :query_seq_len , :kv_size ].astype (query .dtype )
444-
415+ x = wrap_flash_attention (query , key , value )
445416 x = _reshape_heads_to_head_dim (x )
446417
447418 return x