@@ -267,21 +267,38 @@ def _tpu_flash_attention(
       use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
   )
   num_context_shards = mesh.shape["context"]
-  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
-  key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
-  value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
-
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+  def _pad_3d(tensor, num_shards):
+    org_len = tensor.shape[1]
+    rem = org_len % num_shards
+    if rem == 0:
+      return tensor, org_len
+    pad_width = [(0, 0)] * tensor.ndim
+    pad_width[1] = (0, num_shards - rem)
+    return jnp.pad(tensor, pad_width), org_len
+
+  query, orig_q_seq_len = _pad_3d(query, num_context_shards)
+  key, _ = _pad_3d(key, num_context_shards)
+  value, _ = _pad_3d(value, num_context_shards)
+
+  # Define 3D sharding specs (Batch, Seq, None)
+  q_axis_names_3d = nn.logical_to_mesh_axes((axis_names_q[0], axis_names_q[2], None))
+  kv_axis_names_3d = nn.logical_to_mesh_axes((axis_names_kv[0], axis_names_kv[2], None))
+
+  # Output spec is still 4D [Batch, Heads, Seq, HeadDim]
+  q_axis_names_4d = nn.logical_to_mesh_axes(axis_names_q)
 
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
-      out_specs=q_axis_names,
+      in_specs=(q_axis_names_3d, kv_axis_names_3d, kv_axis_names_3d),
+      out_specs=q_axis_names_4d,
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
+    # Reshape to 4D inside shard_map to avoid All-Gather during transpose
+    query = _unflatten_heads(query, heads)
+    key = _unflatten_heads(key, heads)
+    value = _unflatten_heads(value, heads)
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (
         block_sizes.block_q,
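
The `_pad_3d` helper pads the sequence axis because it is sharded across the "context" mesh dimension, so its length must divide evenly by the shard count; the original length is returned alongside, presumably so callers can slice the output back afterwards. A minimal standalone sketch of the same trick, assuming plain jnp arrays (the `pad_to_multiple` name and the example shapes are illustrative, not part of this change):

import jax.numpy as jnp

def pad_to_multiple(tensor, num_shards):
  # Zero-pad the sequence axis (axis 1) up to a multiple of num_shards.
  orig_len = tensor.shape[1]
  rem = orig_len % num_shards
  if rem == 0:
    return tensor, orig_len
  pad_width = [(0, 0)] * tensor.ndim
  pad_width[1] = (0, num_shards - rem)  # pad only at the tail of the sequence
  return jnp.pad(tensor, pad_width), orig_len

x = jnp.ones((2, 10, 512))                # [batch, seq, heads * head_dim]
padded, orig_len = pad_to_multiple(x, 4)  # seq 10 -> 12, divisible by 4 shards
# After the sharded attention call, trim back: out = out[..., :orig_len, :]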
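
The spec change follows a pattern worth spelling out: inputs enter shard_map as flattened 3D tensors with only batch and sequence sharded, and the [batch, heads, seq, head_dim] layout is recovered per shard, so the head-axis reshape/transpose never touches globally sharded data and XLA has no reason to insert an All-Gather. A toy sketch of that pattern under assumed names (the mesh setup, `unflatten_heads`, and the raw PartitionSpecs are illustrative stand-ins for the logical-axis machinery in this file):

import functools
import jax
import jax.numpy as jnp
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices()[:1], axis_names=("context",))
heads, head_dim = 8, 64

def unflatten_heads(x, heads):
  # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
  b, s, _ = x.shape
  return x.reshape(b, s, heads, -1).transpose(0, 2, 1, 3)

@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=P(None, "context", None),         # 3D in: shard the sequence axis
    out_specs=P(None, None, "context", None),  # 4D out: seq is now axis 2
    check_rep=False,
)
def per_shard_attention(q):
  q = unflatten_heads(q, heads)  # reshape/transpose runs on local shards only
  return q                       # a real kernel would call flash attention here

q = jnp.ones((2, 16, heads * head_dim))
print(per_shard_attention(q).shape)  # (2, 8, 16, 64)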