 import jax
 from jax.ad_checkpoint import checkpoint_name
 import jax.numpy as jnp
-from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from maxdiffusion.kernels.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from maxdiffusion.kernels.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
 from maxdiffusion.kernels.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
+from jax.experimental.shard_map import shard_map
 from maxdiffusion.kernels.splash_attention import base as tokamax_splash_base
 from einops import rearrange
 from .. import common_types, max_logging
@@ -303,6 +303,7 @@ def convert_to_tokamax_splash_config(
       dq_reduction_steps=dq_reduction_steps,
   )
 
+
 def _tpu_flash_attention(
     query: jax.Array,
     key: jax.Array,
@@ -314,82 +315,88 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
-    mask_padding_tokens: bool = False,
+    mask_padding_tokens: bool = True,
    residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
     use_base2_exp: bool = False,
     use_experimental_scheduler: bool = False,
 ) -> jax.Array:
-  """Torchax-Style Tensor Parallel TPU Flash Attention"""
-
-  # 1. Reshape from (Batch, Seq_Len, Heads*Dim) -> (Batch, Heads, Seq_Len, Dim)
-  # We pass num_context_shards=1 because we hold the full sequence locally.
-  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards=1)
-  key, _ = _reshape_data_for_flash(key, heads, num_context_shards=1)
-  value, _ = _reshape_data_for_flash(value, heads, num_context_shards=1)
-
-  # 2. Call the built-in Torchax shard_map wrapper!
-  # This automatically handles the SPMD boundaries, the min() padding logic for
-  # cross-attention, and the axis swapping.
-  x = custom_splash.tpu_custom_attention(
-      query=query,
-      key=key,
-      value=value,
-      mesh=mesh,
-      # CRITICAL: JAX already scaled `key` by 1/sqrt(d) in the router.
-      # Passing 1.0 ensures the kernel only applies the log2(e) multiplier to `query`.
-      scale=1.0,
-      block_q=2048,
-      block_kv=2048,
-      block_kv_compute=1024,
-      block_kv_compute_in=256,
-      heads_per_tile=1
+  """TPU Flash Attention"""
+  num_context_shards = mesh.shape["context"]
+  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
+  key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
+  value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
+
+  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+
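+  # wrap_flash_attention runs under shard_map: each device receives its shard of q/kv according to in_specs and returns its shard of the output per out_specs.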
+  @functools.partial(
+      shard_map,
+      mesh=mesh,
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
+      out_specs=q_axis_names,
+      check_rep=False,
+  )
+  def wrap_flash_attention(query, key, value):
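+    # Collect the q/kv block sizes used by the forward and backward kernels; the fused backward kernel only involves the dkv blocks, otherwise the dq blocks are included too.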
+    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+    block_q_sizes = (
+        block_sizes.block_q,
+        block_sizes.block_q_dkv,
     )
+    block_kv_sizes = (
+        block_sizes.block_kv,
+        block_sizes.block_kv_dkv,
+    )
+    if uses_fused_kernel:
+      block_q_sizes += (block_sizes.block_q_dkv,)
+      block_kv_sizes += (block_sizes.block_kv_dkv,)
+    else:
+      block_q_sizes += (block_sizes.block_q_dq,)
+      block_kv_sizes += (block_sizes.block_kv_dq,)
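+    # Pad q/kv up to the largest of these block sizes; the unpadded lengths (query_seq_len / key_seq_len) are reused below for segment ids and output trimming.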
+    block_q = max(*block_q_sizes)
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+
+    block_kv = max(*block_kv_sizes)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
+    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
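+    # Segment ids mark real tokens (1) vs. padding (0); the kernel only lets positions with matching q/kv segment ids attend to each other, masking out the padded tail.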
+    q_padded_len = query.shape[2]
+    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+    kv_padded_len = key.shape[2]
+    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+
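+    # Fold an optional caller-provided attention_mask (only the first batch element is used) into the kv segment ids: masked positions become 0, real tokens beyond the mask stay valid, padding stays 0.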
+    if attention_mask is not None:
+      mask_len = min(key_seq_len, attention_mask.shape[1])
+      kv_mask_for_batch = attention_mask[0, :mask_len]
+      if key_seq_len > mask_len:
+        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
+        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
+      if kv_padded_len > key_seq_len:
+        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
+        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
+      else:
+        kv_mask_padded = kv_mask_for_batch
+      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+
+    if attention_kernel == "tokamax_ring":
+      segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+    else:
+      segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
 
-  # 3. Trim back to original sequence length
-  x = x[:, :, :orig_q_seq_len, :]
-
-  # 4. Flatten back to (Batch, Seq_Len, Heads * Head_Dim)
-  x = _reshape_heads_to_head_dim(x)
-
-  return x
-
-
-# def _tpu_flash_attention(
-#     query: jax.Array,
-#     key: jax.Array,
-#     value: jax.Array,
-#     heads: int,
-#     mesh: Mesh,
-#     axis_names_q: AxisNames,
-#     axis_names_kv: AxisNames,
-#     flash_block_sizes: BlockSizes,
-#     dtype: jnp.dtype = jnp.float32,
-#     attention_kernel: str = "flash",
-#     mask_padding_tokens: bool = False,
-#     residual_checkpoint_name: str | None = None,
-#     attention_mask: jax.Array = None,
-# ) -> jax.Array:
-#   """TPU Flash Attention"""
-
-#   num_context_shards = mesh.shape["context"]
-#   query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
-#   key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
-#   value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
-#   block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
-
-#   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-#   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-
-    # make_splash_mha is wrapped around shardmap and seq and head is already
-    # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
     if attention_kernel == "tokamax_flash":
       mask = tokamax_splash_attention_mask.FullMask(
           _shape=(query.shape[2], key.shape[2]),
       )
       splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
           mask=mask,
-          q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
+          q_seq_shards=1,
           config=convert_to_tokamax_splash_config(
               block_sizes,
               residual_checkpoint_name=residual_checkpoint_name,
@@ -413,7 +420,7 @@ def _tpu_flash_attention(
           ),
           save_residuals=False,
           ring_axis="context",
-          rotate_segment_ids=False,  # We don't rotate segment ids in tokamax ring attention because our segment ids is for padding each kv shard has same segment ids
+          rotate_segment_ids=False,
       )
     else:
       splash_kernel = splash_attention_kernel.make_splash_mha(
@@ -425,134 +432,65 @@ def _tpu_flash_attention(
           residual_checkpoint_name=residual_checkpoint_name,
       )
 
-    # query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
-
-    # #block_kv = max(*block_kv_sizes)
-    # key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
-    # value, _, _ = _pad_data_for_flash(value, heads, bkv)
-
-    # bsizes = custom_splash._BlockSizes(
-    #     block_q=bq,
-    #     block_kv=bkv,
-    #     block_kv_compute=bkv_compute,
-    # )
-
-    # # mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    # # multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-
-    # q_padded_len = query.shape[2]
-    # q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-    # q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
-
-    # kv_padded_len = key.shape[2]
-    # kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-    # kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
-
-    # # make_splash_mha is wrapped around shardmap and seq and head is already
-    # # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
-    # if attention_kernel == "tokamax_flash":
-    #   mask = tokamax_splash_attention_mask.FullMask(
-    #       _shape=(query.shape[2], key.shape[2]),
-    #   )
-    #   splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
-    #       mask=mask,
-    #       q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
-    #       config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-    #       save_residuals=False,
-    #   )
-    # elif attention_kernel == "tokamax_ring":
-    #   splash_kernel = custom_splash.make_splash_mha(
-    #       block_sizes=bsizes,
-    #       bkv_compute_in=bkv_compute_in,
-    #       orig_q_seq_len=query_seq_len,
-    #       orig_kv_seq_len=key_seq_len,
-    #       heads_per_tile=heads_per_tile
-    #   )
-    # else:
-    #   splash_kernel = custom_splash.make_splash_mha(
-    #       block_sizes=bsizes,
-    #       bkv_compute_in=bkv_compute_in,
-    #       orig_q_seq_len=query_seq_len,
-    #       orig_kv_seq_len=key_seq_len,
-    #       heads_per_tile=heads_per_tile
-    #   )
-
-    # vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0), out_axes=0)
-
-    # if not mask_padding_tokens:
-    #   segment_ids = None
-    # if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
-    #   query_scaled = query * 1.44269504
-    #   attention_output = vmapped_splash(query_scaled, key, value)
-    #   attention_output = jnp.swapaxes(attention_output, 2, 3)
-    # else:
-    #   if num_context_shards > 1:
-    #     out, (lse,) = vmapped_splash(query, key, value, segment_ids)
-    #     m = lse.astype(jnp.float32)
-    #     l = jnp.exp(lse - m)
-    #     o = out.astype(jnp.float32) * l[..., None]
-
-    #     perm = [(j, (j + 1) % num_context_shards) for j in range(num_context_shards)]
-
-    #     k1 = jax.lax.ppermute(key, axis_name="context", perm=perm)
-    #     v1 = jax.lax.ppermute(value, axis_name="context", perm=perm)
-
-    #     def ring_scan_body(carry, _):
-    #       m, l, o, k_current, v_current = carry
-    #       k_next = jax.lax.ppermute(k_current, axis_name="context", perm=perm)
-    #       v_next = jax.lax.ppermute(v_current, axis_name="context", perm=perm)
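+    # vmap the splash kernel over the batch dimension; segment_ids are shared across the batch (in_axes=None).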
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+
+    if not mask_padding_tokens:
+      segment_ids = None
+    if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+    else:
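+      # Ring attention path: rotate K/V shards around the "context" mesh axis with ppermute and merge each chunk's partial result via an online softmax.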
+      if num_context_shards > 1:
+        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+        m = lse.astype(jnp.float32)
+        l = jnp.exp(lse - m)
+        o = out.astype(jnp.float32) * l[..., None]
+
+        perm = [(j, (j + 1) % num_context_shards) for j in range(num_context_shards)]
+
+        k1 = jax.lax.ppermute(key, axis_name="context", perm=perm)
+        v1 = jax.lax.ppermute(value, axis_name="context", perm=perm)
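+        # k1/v1 hold a neighboring shard's K/V after one rotation; each scan step consumes the chunk it currently holds and keeps rotating.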
+
+        def ring_scan_body(carry, _):
+          m, l, o, k_current, v_current = carry
+          k_next = jax.lax.ppermute(k_current, axis_name="context", perm=perm)
+          v_next = jax.lax.ppermute(v_current, axis_name="context", perm=perm)
+
+          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
+
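+          # Online-softmax merge: rescale the running accumulator by exp(m_old - m_new) and add the new chunk weighted by exp(lse_chunk - m_new).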
+          m_chunk = lse_chunk.astype(jnp.float32)
+          m_old = m
+          m = jnp.maximum(m_old, m_chunk)
+
+          exp_m_diff = jnp.exp(m_old - m)
+          exp_m_chunk_diff = jnp.exp(m_chunk - m)
+
+          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+          o = o * exp_m_diff[..., None]
+          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+
+          return (m, l, o, k_next, v_next), None
+
+        initial_carry = (m, l, o, k1, v1)
+        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(
+            ring_scan_body, initial_carry, None, length=num_context_shards - 1
+        )
+
+        attention_output = o_final / l_final[..., None]
+      else:
+        raise ValueError("ring attention requires context > 1")
+    return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 
   devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
-  # This warning might show up when doing model eval for example, when calculating model flops
-  # and that is expected.
   if not (query.shape[0] / devices_in_batch_sharding).is_integer():
     max_logging.log(
         "Warning, batch dimension should be shardable among the devices in data and fsdp"
         f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
     )
   x = wrap_flash_attention(query, key, value)
-  # Trim back to original sequence length after context-axis padding.
   x = x[:, :, :orig_q_seq_len, :]
   x = _reshape_heads_to_head_dim(x)
 
-    #       m_chunk = lse_chunk.astype(jnp.float32)
-    #       m_old = m
-    #       m = jnp.maximum(m_old, m_chunk)
-
-    #       exp_m_diff = jnp.exp(m_old - m)
-    #       exp_m_chunk_diff = jnp.exp(m_chunk - m)
-
-    #       l = l * exp_m_diff + jnp.exp(lse_chunk - m)
-    #       o = o * exp_m_diff[..., None]
-    #       o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
-
-    #       # Return the updated state for the next iteration
-    #       return (m, l, o, k_next, v_next), None
-
-    #     initial_carry = (m, l, o, k1, v1)
-    #     (m_final, l_final, o_final, _, _), _ = jax.lax.scan(
-    #         ring_scan_body, initial_carry, None, length=num_context_shards - 1
-    #     )
-
-    #     attention_output = o_final / l_final[..., None]
-    #   else:
-    #     raise ValueError("ring attention requires context > 1")
-    # return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
-
-    # devices_in_data_context = mesh.shape["data"] * mesh.shape["context"]
-    # # This warning might show up when doing model eval for example, when calculating model flops
-    # # and that is expected.
-    # if not (query.shape[0] / devices_in_data_context).is_integer():
-    #   max_logging.log(
-    #     "Warning, batch dimension should be shardable among the devices in data and context"
-    #     f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
-    #   )
-    # x = wrap_flash_attention(query, key, value)
-    # # Trim back to original sequence length after context-axis padding.
-    # x = x[:, :, :orig_q_seq_len, :]
-    # x = _reshape_heads_to_head_dim(x)
-
-    # return x
+  return x
 
 
 # ---------------------------------------------------------------------------
@@ -720,7 +658,7 @@ def _ulysses_custom_attention(
         "Ulysses attention requires the number of heads to be divisible by the context shard count, "
         f"got heads={num_heads} and context_shards={num_shards}."
     )
-
+
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
@@ -755,7 +693,7 @@ def wrap_ulysses_attention(query, key, value):
         bkv_compute_in=bkv_compute_in,
         orig_q_seq_len=query_seq_len,
         orig_kv_seq_len=key_seq_len,
-        heads_per_tile=heads_per_tile
+        heads_per_tile=heads_per_tile,
     )
 
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
@@ -779,11 +717,11 @@ def wrap_ulysses_attention(query, key, value):
 
   return x
 
-def torchax_attention(
-
-):
+
+def torchax_attention():
   pass
 
+
 def _apply_attention_dot(
     query: Array,
     key: Array,