@@ -258,6 +258,26 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
+
+  # Pad sequence dimension so it is evenly divisible by the context mesh axis,
+  # which shard_map requires. The output is trimmed back afterwards and the
+  # existing segment-ID masking inside wrap_flash_attention ensures padded
+  # positions do not affect the result.
+  orig_q_seq_len = query.shape[2]
+  if num_context_shards > 1:
+
+    def _pad_seq_to_context(arr, axis=2):
+      rem = arr.shape[axis] % num_context_shards
+      if rem == 0:
+        return arr
+      pad_width = [(0, 0)] * arr.ndim
+      pad_width[axis] = (0, num_context_shards - rem)
+      return jnp.pad(arr, pad_width)
+
+    query = _pad_seq_to_context(query)
+    key = _pad_seq_to_context(key)
+    value = _pad_seq_to_context(value)
+
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

@@ -401,6 +421,8 @@ def ring_scan_body(carry, _):
401421 f" axis, batch dimension: { query .shape [0 ]} , devices_in_data_context: { devices_in_data_context } "
402422 )
403423 x = wrap_flash_attention (query , key , value )
424+ # Trim back to original sequence length after context-axis padding.
425+ x = x [:, :, :orig_q_seq_len , :]
404426 x = _reshape_heads_to_head_dim (x )
405427
406428 return x
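
For context, here is a minimal standalone sketch (not part of the commit) of the pad-and-trim round trip this change relies on; the array shape and the num_context_shards value of 4 are illustrative assumptions, not taken from the repository:

import jax.numpy as jnp

num_context_shards = 4  # illustrative; in the real code this comes from the mesh

def _pad_seq_to_context(arr, axis=2):
  # Zero-pad `axis` up to the next multiple of num_context_shards.
  rem = arr.shape[axis] % num_context_shards
  if rem == 0:
    return arr
  pad_width = [(0, 0)] * arr.ndim
  pad_width[axis] = (0, num_context_shards - rem)
  return jnp.pad(arr, pad_width)

# (batch, heads, seq, head_dim) with seq = 10, not divisible by 4.
query = jnp.ones((2, 8, 10, 64))
orig_q_seq_len = query.shape[2]

padded = _pad_seq_to_context(query)
assert padded.shape[2] == 12  # padded to the next multiple of num_context_shards

# Slicing the output back recovers the original sequence length and contents.
trimmed = padded[:, :, :orig_q_seq_len, :]
assert trimmed.shape == query.shape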