Skip to content

Commit 56c2fc7

Browse files
committed
Fix Ulysses attention error: select flash block sizes after reshaping query/key/value for context sharding
1 parent 3d50d6f commit 56c2fc7

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -287,11 +287,11 @@ def _tpu_flash_attention(
287287
) -> jax.Array:
288288
"""TPU Flash Attention"""
289289

290-
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
291290
num_context_shards = mesh.shape["context"]
292291
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
293292
key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
294293
value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
294+
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
295295

296296
q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
297297
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

0 commit comments

Comments
 (0)