@@ -38,6 +38,7 @@
 AxisNames = common_types.AxisNames
 BATCH = common_types.BATCH
 LENGTH = common_types.LENGTH
+KV_LENGTH = common_types.KV_LENGTH
 HEAD = common_types.HEAD
 D_KV = common_types.D_KV
 EMBED = common_types.EMBED
@@ -156,7 +157,8 @@ def _tpu_flash_attention(
     value: jax.Array,
     heads: int,
     mesh: Mesh,
-    flash_axis_names: AxisNames,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
 ) -> jax.Array:
@@ -181,8 +183,8 @@ def _tpu_flash_attention(
   query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
   key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
-  axis_names = nn.logical_to_mesh_axes(flash_axis_names)
-  kv_axis_names = nn.logical_to_mesh_axes((BATCH, HEAD, None, D_KV))
+  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
   flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
   axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
   named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
@@ -200,7 +202,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=shard_head_size,  # size of the mesh axis sharded over heads
-        q_seq_shards=num_fsdp_shards,
+        q_seq_shards=num_fsdp_shards,  # size of the mesh axis sharded over seq_len
         block_sizes=block_sizes,
     )
     return splash_kernel
@@ -213,12 +215,12 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
       shard_map.shard_map,
       mesh=mesh,
       in_specs=(
-          axis_names,
+          q_axis_names,
           kv_axis_names,
           kv_axis_names,
           segment_axis_names_splash_kernel,
       ),
-      out_specs=axis_names,
+      out_specs=q_axis_names,
       check_rep=False
   )
   def wrap_flash_attention(query, key, value, splash_kernel):
@@ -359,7 +361,8 @@ def _apply_attention(
     scale: float,
     dtype: jnp.dtype,
     mesh: Mesh,
-    flash_axis_names: AxisNames,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
 ):
@@ -382,7 +385,7 @@ def _apply_attention(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
   elif attention_kernel == "flash":
-    return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype)
+    return _tpu_flash_attention(query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype)
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
@@ -505,7 +508,8 @@ def __init__(
       use_memory_efficient_attention: bool = False,
       split_head_dim: bool = False,
       float32_qk_product: bool = True,
-      flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
+      axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
+      axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
       flash_min_seq_length: int = 4096,
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
@@ -523,7 +527,8 @@ def __init__(
     self.use_memory_efficient_attention = use_memory_efficient_attention
     self.split_head_dim = split_head_dim
     self.float32_qk_product = float32_qk_product
-    self.flash_axis_names = flash_axis_names
+    self.axis_names_q = axis_names_q
+    self.axis_names_kv = axis_names_kv
     self.flash_min_seq_length = flash_min_seq_length
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
@@ -544,7 +549,8 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         scale=self.scale,
         dtype=self.dtype,
         mesh=self.mesh,
-        flash_axis_names=self.flash_axis_names,
+        axis_names_q=self.axis_names_q,
+        axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
     )
@@ -559,7 +565,8 @@ class AttentionOp(nn.Module):
   use_memory_efficient_attention: bool = False
   split_head_dim: bool = False
   float32_qk_product: bool = True
-  flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
+  axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
+  axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV)
   flash_min_seq_length: int = 4096
   flash_block_sizes: BlockSizes = None
   dtype: DType = jnp.float32
@@ -600,7 +607,8 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         scale=self.scale,
         dtype=self.dtype,
         mesh=self.mesh,
-        flash_axis_names=self.flash_axis_names,
+        axis_names_q=self.axis_names_q,
+        axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
     )
@@ -764,9 +772,6 @@ def __call__(
     key_proj = _unflatten_heads(key_proj, self.heads)
     value_proj = _unflatten_heads(value_proj, self.heads)
     query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-    query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec("data", "tensor", None, None))
-    key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec("data", "tensor", None, None))
-    value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec("data", "tensor", None, None))
 
     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
     attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", None, None))
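
For reviewers, a minimal usage sketch of the split query/KV axis names introduced above. This is illustrative only: the `AttentionOp` fields come from the diff, but the mesh construction, the `common_types` import path, and names like `flash_block_sizes` are assumptions, not the repo's actual call sites.

```python
# Hypothetical sketch -- the mesh axis names ("data", "fsdp", "tensor") and the
# import path for common_types are assumptions based on this diff.
import jax
import numpy as np
from jax.sharding import Mesh
from maxdiffusion import common_types

devices = np.array(jax.devices()).reshape(-1, 1, 1)
mesh = Mesh(devices, ("data", "fsdp", "tensor"))

attention_op = AttentionOp(  # other required fields elided
    mesh=mesh,
    # Query keeps LENGTH, so its sequence axis can still be sharded over fsdp.
    axis_names_q=(common_types.BATCH, common_types.HEAD, common_types.LENGTH, common_types.D_KV),
    # Key/value use the new KV_LENGTH logical axis, so a cross-attention KV
    # sequence of a different length gets its own sharding rule.
    axis_names_kv=(common_types.BATCH, common_types.HEAD, common_types.KV_LENGTH, common_types.D_KV),
    flash_block_sizes=flash_block_sizes,  # assumed defined elsewhere
)
attn_output = attention_op.apply_attention(query_proj, key_proj, value_proj)
```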
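And a standalone sketch of why the second name set matters: `nn.logical_to_mesh_axes` resolves each logical axis name through the active rules, and any name without a rule falls back to `None` (replicated), so the KV sequence axis is no longer hard-coded to the query's sharding. The rule set and logical names below are invented for illustration; the repo's real names live in `common_types`.

```python
import flax.linen as nn

rules = (
    ("batch", "data"),
    ("heads", "tensor"),
    ("length", "fsdp"),
    # no rule for "kv_length" or "kv": both resolve to None (unsharded)
)

with nn.logical_axis_rules(rules):
    q_spec = nn.logical_to_mesh_axes(("batch", "heads", "length", "kv"))
    kv_spec = nn.logical_to_mesh_axes(("batch", "heads", "kv_length", "kv"))

print(q_spec)   # PartitionSpec('data', 'tensor', 'fsdp', None)
print(kv_spec)  # PartitionSpec('data', 'tensor', None, None)
```

This also appears to be why the `with_sharding_constraint` calls on the projections were removed in `__call__`: with the KV spec now derived from `axis_names_kv` inside `_tpu_flash_attention`, pinning query/key/value to a fixed `PartitionSpec` beforehand is presumably redundant.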