@@ -258,6 +258,24 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
+
+  # Pad the sequence dimension so it is evenly divisible by the context mesh axis,
+  # which shard_map requires. The output is trimmed back afterwards, and the
+  # existing segment-ID masking inside wrap_flash_attention ensures padded
+  # positions do not affect the result.
+  orig_q_seq_len = query.shape[2]
+  if num_context_shards > 1:
+    def _pad_seq_to_context(arr, axis=2):
+      rem = arr.shape[axis] % num_context_shards
+      if rem == 0:
+        return arr
+      pad_width = [(0, 0)] * arr.ndim
+      pad_width[axis] = (0, num_context_shards - rem)
+      return jnp.pad(arr, pad_width)
+    query = _pad_seq_to_context(query)
+    key = _pad_seq_to_context(key)
+    value = _pad_seq_to_context(value)
+
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
@@ -401,6 +419,8 @@ def ring_scan_body(carry, _):
         f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
     )
   x = wrap_flash_attention(query, key, value)
+  # Trim back to original sequence length after context-axis padding.
+  x = x[:, :, :orig_q_seq_len, :]
   x = _reshape_heads_to_head_dim(x)
 
   return x
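For context, here is a minimal standalone sketch of the pad-then-trim round trip this diff introduces. The helper name pad_seq_to_multiple, the toy (batch, heads, seq, head_dim) shapes, and the identity stand-in for wrap_flash_attention are illustrative assumptions; only the shape bookkeeping is exercised.

import jax.numpy as jnp

def pad_seq_to_multiple(arr, multiple, axis=2):
  """Zero-pad `axis` so its length is evenly divisible by `multiple`."""
  rem = arr.shape[axis] % multiple
  if rem == 0:
    return arr
  pad_width = [(0, 0)] * arr.ndim
  pad_width[axis] = (0, multiple - rem)
  return jnp.pad(arr, pad_width)

# Toy shapes: (batch, heads, seq, head_dim) with seq=10 and 4 context shards.
num_context_shards = 4
q = jnp.ones((2, 8, 10, 64))
orig_q_seq_len = q.shape[2]

q_padded = pad_seq_to_multiple(q, num_context_shards)  # seq padded 10 -> 12
out = q_padded                                          # identity stand-in for wrap_flash_attention
out = out[:, :, :orig_q_seq_len, :]                     # trim back to seq=10
assert out.shape == q.shape

Because the padding is appended only at the end of the sequence axis and the padded positions are masked out via segment IDs, slicing the output back to orig_q_seq_len recovers exactly the unpadded result.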