
Commit 71cba8d

Remove _resolve_tpu_attention_block_sizes, consolidate into _select_flash_block_sizes
1 parent 292fd84

2 files changed: 19 additions & 54 deletions


src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 46 deletions
@@ -270,38 +270,6 @@ def convert_to_tokamax_splash_config(
   )
 
 
-def _resolve_tpu_attention_block_sizes(
-    query_seq_len: int,
-    kv_seq_len: int,
-    flash_block_sizes: BlockSizes,
-    dtype: jnp.dtype,
-    attention_kernel: str = "flash",
-) -> BlockSizes:
-  """Resolve TPU splash attention block sizes for self- and cross-attention."""
-  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  is_cross_attention = kv_seq_len != query_seq_len
-  if is_cross_attention:
-    kv_max_block_size = ((kv_seq_len + 127) // 128) * 128
-  else:
-    kv_max_block_size = q_max_block_size
-
-  if flash_block_sizes and not is_cross_attention:
-    return flash_block_sizes
-
-  block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
-  return splash_attention_kernel.BlockSizes(
-      block_q=block_size_q,
-      block_kv_compute=min(kv_max_block_size, kv_seq_len),
-      block_kv=min(kv_max_block_size, kv_seq_len),
-      block_q_dkv=block_size_q,
-      block_kv_dkv=min(kv_max_block_size, kv_seq_len),
-      block_kv_dkv_compute=min(kv_max_block_size, query_seq_len),
-      block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
-      block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query_seq_len),
-      use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
-  )
-
-
 def _tpu_flash_attention(
     query: jax.Array,
     key: jax.Array,
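
Note for readers: the heart of the removed helper is a small piece of arithmetic. The query block is capped at 1024 for bfloat16 (512 otherwise), and for cross-attention the KV cap is the KV sequence length rounded up to the next multiple of 128. A minimal standalone sketch of that rule (illustrative only, not the library code; sketch_block_size_caps is a hypothetical name):

def sketch_block_size_caps(query_seq_len: int, kv_seq_len: int, is_bf16: bool) -> tuple[int, int]:
  """Restates the removed helper's cap arithmetic, for illustration."""
  q_max = 1024 if is_bf16 else 512
  if kv_seq_len != query_seq_len:  # cross-attention
    kv_max = ((kv_seq_len + 127) // 128) * 128  # round KV up to a multiple of 128
  else:  # self-attention reuses the query cap
    kv_max = q_max
  return q_max, kv_max

print(sketch_block_size_caps(257, 513, is_bf16=False))  # (512, 640) -- the shapes the tests below use
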
@@ -319,18 +287,11 @@ def _tpu_flash_attention(
 ) -> jax.Array:
   """TPU Flash Attention"""
 
-  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
   num_context_shards = mesh.shape["context"]
   query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
   key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
   value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
-  block_sizes = _resolve_tpu_attention_block_sizes(
-      query_seq_len=query.shape[2],
-      kv_seq_len=key.shape[2],
-      flash_block_sizes=flash_block_sizes,
-      dtype=dtype,
-      attention_kernel=attention_kernel,
-  )
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
 
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
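
The call-site change above is shape-for-shape equivalent: the removed helper was handed query.shape[2] and key.shape[2] explicitly, while _select_flash_block_sizes takes the arrays themselves after _reshape_data_for_flash (presumably reading the same sequence axis internally; its body is not part of this diff). Side by side:

# before: sequence lengths extracted at the call site
block_sizes = _resolve_tpu_attention_block_sizes(
    query_seq_len=query.shape[2],
    kv_seq_len=key.shape[2],
    flash_block_sizes=flash_block_sizes,
    dtype=dtype,
    attention_kernel=attention_kernel,
)

# after: one call, arrays passed post-reshape
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
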
@@ -530,12 +491,7 @@ def _ulysses_attention(
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
-  block_sizes = _resolve_tpu_attention_block_sizes(
-      query_seq_len=query.shape[2],
-      kv_seq_len=key.shape[2],
-      flash_block_sizes=flash_block_sizes,
-      dtype=dtype,
-  )
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")
 
   @functools.partial(
       jax.shard_map,
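
In _ulysses_attention, the old call relied on the removed helper's default attention_kernel: str = "flash"; the replacement spells that default out, keeping the Ulysses path off the tokamax-specific backward-pass layout:

# removed default, from the old signature: attention_kernel: str = "flash"
# the new call makes the same choice explicit:
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")
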

src/maxdiffusion/tests/attention_test.py

Lines changed: 17 additions & 8 deletions
@@ -186,21 +186,30 @@ def test_default_flash_block_sizes_use_sequence_axis_for_3d_inputs(self):
     assert block_sizes.block_q_dq == 1024
     assert block_sizes.block_kv_dq == 128
 
-  def test_resolve_tpu_attention_block_sizes(self):
-    """Shared block-size selection should keep self-attn overrides and derive cross-attn defaults."""
+  def test_select_flash_block_sizes_returns_configured_for_self_attention(self):
+    """Block-size selection should return the configured sizes unchanged for self-attention."""
     custom_block_sizes = self._ulysses_block_sizes(block_size=16)
+    query = jnp.zeros((1, 128, 1), dtype=jnp.float32)
+    key = jnp.zeros((1, 128, 1), dtype=jnp.float32)
 
-    self_attention_block_sizes = attention_flax._resolve_tpu_attention_block_sizes(
-        query_seq_len=128,
-        kv_seq_len=128,
+    self_attention_block_sizes = _select_flash_block_sizes(
+        query=query,
+        key=key,
         flash_block_sizes=custom_block_sizes,
         dtype=jnp.float32,
+        attention_kernel="flash",
     )
     self.assertIs(self_attention_block_sizes, custom_block_sizes)
 
-    cross_attention_block_sizes = attention_flax._resolve_tpu_attention_block_sizes(
-        query_seq_len=257,
-        kv_seq_len=513,
+  def test_select_flash_block_sizes_derives_cross_attn_defaults_for_tokamax(self):
+    """Block-size selection should derive cross-attn defaults and set tokamax_flash flags."""
+    custom_block_sizes = self._ulysses_block_sizes(block_size=16)
+    query = jnp.zeros((1, 257, 1), dtype=jnp.float32)
+    key = jnp.zeros((1, 513, 1), dtype=jnp.float32)
+
+    cross_attention_block_sizes = _select_flash_block_sizes(
+        query=query,
+        key=key,
         flash_block_sizes=custom_block_sizes,
         dtype=jnp.float32,
         attention_kernel="tokamax_flash",
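
A quick spot-check of what the split tests pin down, assuming _select_flash_block_sizes preserves the removed helper's arithmetic: the self-attention case (128 == 128) hits the early-return branch (flash_block_sizes and not is_cross_attention), which is why assertIs can demand the very same object back, while the cross-attention case derives sizes even though custom sizes were passed:

query_seq_len, kv_seq_len = 257, 513  # from the test's jnp.zeros shapes
assert kv_seq_len != query_seq_len  # cross-attention, so no early return
kv_max = ((kv_seq_len + 127) // 128) * 128
assert kv_max == 640  # 513 rounded up to the next multiple of 128
assert min(kv_max, kv_seq_len) == 513  # expected block_kv / block_kv_compute
# and with attention_kernel="tokamax_flash", per the removed implementation:
# block_q_dq=None, block_kv_dq=None, use_fused_bwd_kernel=True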
