Skip to content

Commit b8296b2

Browse files
committed
Fix flash attention shard_map for sequence lengths not divisible by context mesh axis
1 parent 2a74af1 commit b8296b2

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,26 @@ def _tpu_flash_attention(
258258
query = _reshape_data_for_flash(query, heads)
259259
key = _reshape_data_for_flash(key, heads)
260260
value = _reshape_data_for_flash(value, heads)
261+
262+
# Pad sequence dimension so it is evenly divisible by the context mesh axis,
263+
# which shard_map requires. The output is trimmed back afterwards and the
264+
# existing segment-ID masking inside wrap_flash_attention ensures padded
265+
# positions do not affect the result.
266+
orig_q_seq_len = query.shape[2]
267+
if num_context_shards > 1:
268+
269+
def _pad_seq_to_context(arr, axis=2):
    """Zero-pad `arr` along `axis` so its length is a multiple of `num_context_shards`.

    Returns `arr` unchanged when the length already divides evenly; otherwise
    appends trailing zeros so shard_map can split the axis across the context mesh.
    """
    remainder = arr.shape[axis] % num_context_shards
    if remainder:
        # Pad only the tail of the target axis; every other axis stays untouched.
        padding = [(0, 0) for _ in range(arr.ndim)]
        padding[axis] = (0, num_context_shards - remainder)
        arr = jnp.pad(arr, padding)
    return arr
276+
277+
query = _pad_seq_to_context(query)
278+
key = _pad_seq_to_context(key)
279+
value = _pad_seq_to_context(value)
280+
261281
q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
262282
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
263283

@@ -401,6 +421,8 @@ def ring_scan_body(carry, _):
401421
f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
402422
)
403423
x = wrap_flash_attention(query, key, value)
424+
# Trim back to original sequence length after context-axis padding.
425+
x = x[:, :, :orig_q_seq_len, :]
404426
x = _reshape_heads_to_head_dim(x)
405427

406428
return x

0 commit comments

Comments (0)