@@ -128,15 +128,25 @@ def _unflatten_heads(tensor, heads):
   return tensor


-def _reshape_data_for_flash(tensor, heads):
+def _reshape_data_for_flash(tensor, heads, num_context_shards=1):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
   Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
   blocks is divisible by the number of shards.
   """
   if tensor.ndim != 4:
     tensor = _unflatten_heads(tensor, heads)
-  return tensor
+
+  # Pad sequence dimension so it is evenly divisible by the context mesh axis,
+  # which shard_map requires.
+  if num_context_shards <= 1:
+    return tensor
+  rem = tensor.shape[2] % num_context_shards
+  if rem == 0:
+    return tensor
+  pad_width = [(0, 0)] * tensor.ndim
+  pad_width[2] = (0, num_context_shards - rem)
+  return jnp.pad(tensor, pad_width)


 def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
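As a reading aid, here is a minimal, self-contained sketch of the padding arithmetic the hunk above adds, assuming the `(batch, heads, seq_len, head_dim)` layout that `_reshape_data_for_flash` produces; the helper name `pad_seq_to_shards` and the concrete shapes are illustrative only and are not part of the change:

```python
import jax.numpy as jnp

def pad_seq_to_shards(tensor, num_context_shards=1):
  """Illustrative copy of the new logic: pad axis 2 (seq_len) up to a
  multiple of the context mesh axis so shard_map can split it evenly."""
  if num_context_shards <= 1:
    return tensor
  rem = tensor.shape[2] % num_context_shards
  if rem == 0:
    return tensor
  pad_width = [(0, 0)] * tensor.ndim
  pad_width[2] = (0, num_context_shards - rem)
  return jnp.pad(tensor, pad_width)

x = jnp.zeros((2, 8, 1026, 128))            # (batch, heads, seq_len, head_dim)
y = pad_seq_to_shards(x, num_context_shards=4)
print(y.shape)                              # (2, 8, 1028, 128): 1028 % 4 == 0
```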
@@ -255,9 +265,11 @@ def _tpu_flash_attention(
       use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
   )
   num_context_shards = mesh.shape["context"]
-  query = _reshape_data_for_flash(query, heads)
-  key = _reshape_data_for_flash(key, heads)
-  value = _reshape_data_for_flash(value, heads)
+  orig_q_seq_len = query.shape[1]
+  query = _reshape_data_for_flash(query, heads, num_context_shards)
+  key = _reshape_data_for_flash(key, heads, num_context_shards)
+  value = _reshape_data_for_flash(value, heads, num_context_shards)
+
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

@@ -401,6 +413,8 @@ def ring_scan_body(carry, _):
401413 f" axis, batch dimension: { query .shape [0 ]} , devices_in_data_context: { devices_in_data_context } "
402414 )
403415 x = wrap_flash_attention (query , key , value )
416+ # Trim back to original sequence length after context-axis padding.
417+ x = x [:, :, :orig_q_seq_len , :]
404418 x = _reshape_heads_to_head_dim (x )
405419
406420 return x
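For context, a hedged end-to-end sketch of the pad-then-trim round trip the last two hunks wire together. `fake_attention` is a stand-in for `wrap_flash_attention`, the shapes are made up, and the sketch works directly on the unflattened `(batch, heads, seq_len, head_dim)` layout, whereas the PR reads the original length off axis 1 of the still-flattened query before `_reshape_data_for_flash`:

```python
import jax.numpy as jnp

def fake_attention(q, k, v):
  # Stand-in for wrap_flash_attention: preserves (batch, heads, seq_len, head_dim).
  return q

num_context_shards = 4
q = jnp.ones((2, 8, 1026, 128))                  # (batch, heads, seq_len, head_dim)
orig_q_seq_len = q.shape[2]

rem = q.shape[2] % num_context_shards            # 1026 % 4 == 2, so padding is needed
q_padded = jnp.pad(q, [(0, 0), (0, 0), (0, num_context_shards - rem), (0, 0)])
out = fake_attention(q_padded, q_padded, q_padded)
out = out[:, :, :orig_q_seq_len, :]              # drop the context-axis padding again
assert out.shape == q.shape
```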