
Commit e22ec2c

Fix flash attention shard_map for sequence lengths not divisible by context mesh axis
1 parent 2a74af1 commit e22ec2c

1 file changed

src/maxdiffusion/models/attention_flax.py
Lines changed: 20 additions & 0 deletions
@@ -258,6 +258,24 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
+
+  # Pad sequence dimension so it is evenly divisible by the context mesh axis,
+  # which shard_map requires. The output is trimmed back afterwards and the
+  # existing segment-ID masking inside wrap_flash_attention ensures padded
+  # positions do not affect the result.
+  orig_q_seq_len = query.shape[2]
+  if num_context_shards > 1:
+    def _pad_seq_to_context(arr, axis=2):
+      rem = arr.shape[axis] % num_context_shards
+      if rem == 0:
+        return arr
+      pad_width = [(0, 0)] * arr.ndim
+      pad_width[axis] = (0, num_context_shards - rem)
+      return jnp.pad(arr, pad_width)
+    query = _pad_seq_to_context(query)
+    key = _pad_seq_to_context(key)
+    value = _pad_seq_to_context(value)
+
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
@@ -401,6 +419,8 @@ def ring_scan_body(carry, _):
       f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
   )
   x = wrap_flash_attention(query, key, value)
+  # Trim back to original sequence length after context-axis padding.
+  x = x[:, :, :orig_q_seq_len, :]
   x = _reshape_heads_to_head_dim(x)

   return x
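
For readers outside the repo, here is a minimal standalone sketch of the pad-then-trim pattern this commit applies. The helper name pad_seq_to_multiple, the [batch, heads, seq, head_dim] layout, and the shard count are assumptions for illustration only, not identifiers taken from attention_flax.py:

import jax.numpy as jnp

def pad_seq_to_multiple(arr, multiple, axis=2):
  # Right-pad `axis` with zeros so its length is divisible by `multiple`.
  rem = arr.shape[axis] % multiple
  if rem == 0:
    return arr
  pad_width = [(0, 0)] * arr.ndim
  pad_width[axis] = (0, multiple - rem)
  return jnp.pad(arr, pad_width)

num_context_shards = 4                                   # hypothetical context mesh axis size
q = jnp.ones((2, 8, 10, 64))                             # assumed [batch, heads, seq=10, head_dim]
orig_len = q.shape[2]

q_padded = pad_seq_to_multiple(q, num_context_shards)    # seq: 10 -> 12, now divisible by 4
out = q_padded                                           # stand-in for the sharded attention output
out = out[:, :, :orig_len, :]                            # trim back to the original length
print(q_padded.shape, out.shape)                         # (2, 8, 12, 64) (2, 8, 10, 64)

Padding with zeros is safe here because, as the commit's own comment notes, the existing segment-ID masking inside wrap_flash_attention keeps the padded positions from contributing to the attention result.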
