@@ -27,7 +27,6 @@
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
-from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
 from einops import rearrange
 from .. import common_types, max_logging
 
@@ -306,92 +305,62 @@ def wrap_flash_attention(query, key, value):
         mask=mask,
         q_seq_shards=1,  # the size of the axis sharded over seq_len
         config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-        save_residuals=True if "ring" in attention_kernel else False,
-    )
-  elif attention_kernel == "tokamax_ring":
-    mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-        mask=mask,
-        is_mqa=False,
-        config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-        save_residuals=True,
-        ring_axis="fsdp",
+        save_residuals=True if attention_kernel == "ring" else False,
     )
   else:
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=1,  # the size of the axis sharded over heads
         q_seq_shards=1,  # the size of the axis sharded over seq_len
         block_sizes=block_sizes,
-        save_residuals=True if "ring" in attention_kernel else False,
+        save_residuals=True if attention_kernel == "ring" else False,
         residual_checkpoint_name=residual_checkpoint_name
     )
+  vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
 
-  if attention_kernel == "tokamax_ring":
-    # For tokamax_ring, use the kernel directly without vmap.
-    # The ring attention kernel handles the ring topology internally.
-    if not mask_padding_tokens:
-      segment_ids = None
-    attention_output = splash_kernel(
-        fwd_mask_info=None,
-        dkv_mask_info=None,
-        q=query,
-        k=key,
-        v=value,
-        segment_ids=segment_ids,
-        is_mqa=False,
-        config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-        mask_value=-jnp.inf,
-        mask_function=None,
-        fwd_mask_sparsity=1.0,
-        save_residuals=True,
-    )
+  if not mask_padding_tokens:
+    segment_ids = None
+  if attention_kernel in ["flash", "tokamax_flash"]:
+    attention_output = vmapped_splash(query, key, value, segment_ids)
   else:
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
-
-    if not mask_padding_tokens:
-      segment_ids = None
-    if attention_kernel in ["flash", "tokamax_flash"]:
-      attention_output = vmapped_splash(query, key, value, segment_ids)
-    else:
-      if num_fsdp_shards > 1:
-        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
-        m = lse.astype(jnp.float32)
-        l = jnp.exp(lse - m)
-        o = out.astype(jnp.float32) * l[..., None]
+    if num_fsdp_shards > 1:
+      out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+      m = lse.astype(jnp.float32)
+      l = jnp.exp(lse - m)
+      o = out.astype(jnp.float32) * l[..., None]
 
-        perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
+      perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
 
-        k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
-        v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
+      k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
+      v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
 
-        def ring_scan_body(carry, _):
-          m, l, o, k_current, v_current = carry
-          k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
-          v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
+      def ring_scan_body(carry, _):
+        m, l, o, k_current, v_current = carry
+        k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
+        v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
 
-          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
+        out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
 
-          m_chunk = lse_chunk.astype(jnp.float32)
-          m_old = m
-          m = jnp.maximum(m_old, m_chunk)
+        m_chunk = lse_chunk.astype(jnp.float32)
+        m_old = m
+        m = jnp.maximum(m_old, m_chunk)
 
-          exp_m_diff = jnp.exp(m_old - m)
-          exp_m_chunk_diff = jnp.exp(m_chunk - m)
+        exp_m_diff = jnp.exp(m_old - m)
+        exp_m_chunk_diff = jnp.exp(m_chunk - m)
 
-          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
-          o = o * exp_m_diff[..., None]
-          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+        l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+        o = o * exp_m_diff[..., None]
+        o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
 
-          # Return the updated state for the next iteration.
-          return (m, l, o, k_next, v_next), None
+        # Return the updated state for the next iteration.
+        return (m, l, o, k_next, v_next), None
 
-        initial_carry = (m, l, o, k1, v1)
-        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
+      initial_carry = (m, l, o, k1, v1)
+      (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
 
-        attention_output = o_final / l_final[..., None]
-      else:
-        raise ValueError("ring attention requires fsdp > 1")
+      attention_output = o_final / l_final[..., None]
+    else:
+      raise ValueError("ring attention requires fsdp > 1")
 
   return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 
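Note on the hunk above: the merge in ring_scan_body is a streaming log-sum-exp combine. With save_residuals=True each splash call returns its normalized partial output together with the log-sum-exp of its scores, perm rotates the key/value shards one hop around the "fsdp" ring per step, and the carry rescales the running accumulator whenever a new chunk raises the running max, so the full score matrix is never materialized. A minimal self-contained sketch of that combine, with chunk_attention and merge as illustrative stand-ins for the splash kernel and the scan body, checked against single-pass attention:

import jax
import jax.numpy as jnp

def chunk_attention(q, k, v):
  # Plain softmax attention over one key chunk; returns the normalized
  # output plus the log-sum-exp of the scores, mirroring the (out, lse)
  # residuals the splash kernel produces.
  s = q @ k.T
  m = jnp.max(s, axis=-1, keepdims=True)
  w = jnp.exp(s - m)
  l = jnp.sum(w, axis=-1)
  out = (w @ v) / l[:, None]
  lse = m[:, 0] + jnp.log(l)
  return out, lse

def merge(carry, chunk):
  # Same rescaling as ring_scan_body: fold one chunk's (out, lse) into the
  # running (max, denominator, unnormalized output).
  m, l, o = carry
  out_c, lse_c = chunk
  m_new = jnp.maximum(m, lse_c)
  l_new = l * jnp.exp(m - m_new) + jnp.exp(lse_c - m_new)
  o_new = o * jnp.exp(m - m_new)[:, None] + jnp.exp(lse_c - m_new)[:, None] * out_c
  return m_new, l_new, o_new

q = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
v = jax.random.normal(jax.random.PRNGKey(2), (16, 8))

ref, _ = chunk_attention(q, k, v)              # one-shot reference
out0, lse0 = chunk_attention(q, k[:8], v[:8])  # local shard, like the first kernel call
m, l, o = lse0, jnp.exp(lse0 - lse0), out0     # initial carry: m = lse, l = 1, o = out
m, l, o = merge((m, l, o), chunk_attention(q, k[8:], v[8:]))  # one ring step
assert jnp.allclose(o / l[:, None], ref, atol=1e-4)

The same invariant is what lets the scan fold num_fsdp_shards - 1 remote shards into the local result before the final o_final / l_final normalization.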
@@ -567,7 +536,7 @@ def _apply_attention(
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name=residual_checkpoint_name,
     )
-  elif "ring" in attention_kernel:
+  elif attention_kernel == "ring":
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
         mask_padding_tokens=mask_padding_tokens,
@@ -578,7 +547,6 @@ def _apply_attention(
     raise ValueError(f"Unexpected attention kernel {attention_kernel=}.")
 
 
-
 def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
   """Multi-head dot product attention with a limited number of queries."""
   num_kv, num_heads, k_features = key.shape[-3:]
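_query_chunk_attention is truncated after its first line in this diff. As a rough sketch of the key-chunking idea behind it (illustrative only, not the body of this function): keys and values are processed key_chunk_size entries at a time, each chunk yields an unnormalized partial sum plus its own score max, and the partials are rebased onto a global max before normalizing, the same rescaling used in the ring merge above.

import math
import jax
import jax.numpy as jnp

def chunked_attention(query, key, value, key_chunk_size=4096):
  # Sketch only: assumes num_kv is a multiple of the chunk size.
  num_kv, num_heads, k_features = key.shape[-3:]
  key_chunk_size = min(key_chunk_size, num_kv)
  query = query / math.sqrt(k_features)  # scale once, outside the chunk loop

  def attend_to_chunk(k_chunk, v_chunk):
    # Unnormalized partial numerator, denominator, and score max per chunk.
    s = jnp.einsum("qhd,khd->qhk", query, k_chunk)
    m = jnp.max(s, axis=-1, keepdims=True)
    w = jnp.exp(s - m)
    return jnp.einsum("qhk,khd->qhd", w, v_chunk), w.sum(axis=-1), m[..., 0]

  k_chunks = key.reshape(-1, key_chunk_size, num_heads, k_features)
  v_chunks = value.reshape(-1, key_chunk_size, num_heads, value.shape[-1])
  o, l, m = jax.vmap(attend_to_chunk)(k_chunks, v_chunks)  # per-chunk partials

  m_global = jnp.max(m, axis=0)
  scale = jnp.exp(m - m_global)  # rebase every chunk onto the global max
  o = (o * scale[..., None]).sum(axis=0)
  l = (l * scale).sum(axis=0)
  return o / l[..., None]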