@@ -187,9 +187,9 @@ def _tpu_flash_attention(
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-  # flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
-  # axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
-  # named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
+  flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
+  axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
+  named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
 
   shard_head_size = mesh.shape["tensor"]
 
@@ -210,7 +210,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
 
   multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
   splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
-  # segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+  segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
 
   @functools.partial(
       shard_map.shard_map,
@@ -219,7 +219,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
           q_axis_names,
           kv_axis_names,
           kv_axis_names,
-          None,
+          segment_axis_names_splash_kernel,
       ),
       out_specs=q_axis_names,
       check_rep=False,
@@ -511,8 +511,8 @@ def __init__(
       use_memory_efficient_attention: bool = False,
       split_head_dim: bool = False,
       float32_qk_product: bool = True,
-      axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
-      axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
+      axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, None),
+      axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, None),
       flash_min_seq_length: int = 4096,
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
@@ -675,7 +675,7 @@ def __init__(
         quant=quant,
     )
 
-    kernel_axes = ("embed", "heads")
+    kernel_axes = ("embed", None)
     qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)
 
     self.query = nnx.Linear(
@@ -715,7 +715,7 @@ def __init__(
715715 rngs = rngs ,
716716 in_features = self .inner_dim ,
717717 out_features = self .inner_dim ,
718- kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ("heads" , "embed" )),
718+ kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), (None , "embed" )),
719719 dtype = dtype ,
720720 param_dtype = weights_dtype ,
721721 precision = precision ,
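
Note on the change above: the previously commented-out lines build a NamedSharding over the kernel's (HEAD, LENGTH) logical axes, turn it into a per-device spec with splash_kernel.manual_sharding_spec, and pass that spec as the fourth in_specs entry of the shard_map-wrapped attention call, so the kernel argument is sharded along the head dimension together with query/key/value. The sketch below is a minimal, self-contained illustration of that wiring, not the repository code: it assumes a TPU mesh with ("data", "tensor") axes, uses explicit PartitionSpecs in place of nn.logical_to_mesh_axes and the file's logical axis rules, and the mask construction and the sharded_splash_attention / wrap_flash_attention names are placeholders.

# Minimal sketch (illustrative, not the repository code) of the path this
# commit enables: build the splash kernel, derive its shard_map input spec
# via manual_sharding_spec, and pass it as the fourth in_specs entry.
# Assumes a TPU mesh with ("data", "tensor") axes.
import functools

import jax
from jax.experimental import shard_map
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)
from jax.sharding import NamedSharding, PartitionSpec as P


def sharded_splash_attention(query, key, value, mesh):
  # query/key/value: (batch, heads, seq, head_dim), already reshaped for flash.
  _, heads, seq, _ = query.shape
  shard_head_size = mesh.shape["tensor"]

  # Full (all-to-all) mask per head; the exact mask used in the file may differ.
  mask = splash_attention_mask.FullMask(_shape=(seq, key.shape[2]))
  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * heads)
  splash_kernel = splash_attention_kernel.make_splash_mha(
      mask=multi_head_mask,
      head_shards=shard_head_size,
      q_seq_shards=1,
  )

  # The kernel is itself a pytree; manual_sharding_spec tells shard_map how to
  # slice it per device, given the (HEAD, LENGTH) sharding of its contents.
  named_sharding = NamedSharding(mesh, P("tensor", None))
  kernel_spec = splash_kernel.manual_sharding_spec(named_sharding)

  q_spec = P("data", "tensor", None, None)  # batch on "data", heads on "tensor"
  kv_spec = P("data", "tensor", None, None)

  @functools.partial(
      shard_map.shard_map,
      mesh=mesh,
      in_specs=(q_spec, kv_spec, kv_spec, kernel_spec),
      out_specs=q_spec,
      check_rep=False,
  )
  def wrap_flash_attention(q, k, v, kernel):
    # Per shard: vmap the single-device kernel over the local batch dimension.
    return jax.vmap(kernel)(q, k, v)

  return wrap_flash_attention(query, key, value, splash_kernel)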