|
57 | 57 | CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH |
58 | 58 |
|
59 | 59 |
|
| 60 | +def _coerce_tokamax_block_sizes(block_sizes): |
| 61 | + # The Tokamax kernels require the fused backward kernel; coerce unfused configs. |
| 62 | + if getattr(block_sizes, "use_fused_bwd_kernel", False): |
| 63 | + return block_sizes |
| 64 | + |
| 65 | + # Fall back to the forward block sizes when a backward field is missing or unset. |
| 66 | + bq = block_sizes.block_q |
| 67 | + bkv = getattr(block_sizes, "block_kv", None) or bq |
| 68 | + bkv_compute = getattr(block_sizes, "block_kv_compute", None) or bkv |
| 69 | + bq_dkv = getattr(block_sizes, "block_q_dkv", None) or bq |
| 70 | + bkv_dkv = getattr(block_sizes, "block_kv_dkv", None) or bkv |
| 71 | + bkv_dkv_compute = getattr(block_sizes, "block_kv_dkv_compute", None) or bkv_compute |
| 72 | + return splash_attention_kernel.BlockSizes( |
| 73 | + block_q=bq, |
| 74 | + block_kv=bkv, |
| 75 | + block_kv_compute=bkv_compute, |
| 76 | + block_q_dkv=bq_dkv, |
| 77 | + block_kv_dkv=bkv_dkv, |
| 78 | + block_kv_dkv_compute=bkv_dkv_compute, |
| 79 | + block_q_dq=None, |
| 80 | + block_kv_dq=None, |
| 81 | + use_fused_bwd_kernel=True, |
| 82 | + ) |
| 83 | + |
| 84 | + |
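For reference, a minimal sketch of what the coercion above does to a fully specified unfused config; `splash_attention_kernel` and `_coerce_tokamax_block_sizes` are the names used in this file, but the block values are illustrative and not taken from this change:

```python
# Hypothetical block values; real configs come from flash_block_sizes below.
unfused = splash_attention_kernel.BlockSizes(
    block_q=1024,
    block_kv=1024,
    block_kv_compute=512,
    block_q_dkv=1024,
    block_kv_dkv=1024,
    block_kv_dkv_compute=512,
    block_q_dq=1024,
    block_kv_dq=1024,
    use_fused_bwd_kernel=False,
)
fused = _coerce_tokamax_block_sizes(unfused)
# `fused` keeps the forward and dKV block sizes, drops the separate dQ blocks
# (block_q_dq / block_kv_dq are None), and sets use_fused_bwd_kernel=True.
```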
60 | 85 | def _maybe_aqt_einsum(quant: Quant): |
61 | 86 | return jnp.einsum if quant is None else quant.einsum() |
62 | 87 |
|
63 | | - |
64 | 88 | def _check_attention_inputs(query: Array, key: Array, value: Array) -> None: |
65 | 89 | """Check attention inputs.""" |
66 | 90 |
|
@@ -231,9 +255,13 @@ def _tpu_flash_attention( |
231 | 255 | kv_max_block_size = key.shape[1] |
232 | 256 | else: |
233 | 257 | kv_max_block_size = q_max_block_size |
| 258 | + |
234 | 259 | # ensure that for cross attention we override the block sizes. |
235 | 260 | if flash_block_sizes and key.shape[1] == query.shape[1]: |
236 | 261 | block_sizes = flash_block_sizes |
| 262 | + use_tokamax = attention_kernel in ["tokamax_flash", "tokamax_ring"] |
| 263 | + if use_tokamax: |
| 264 | + block_sizes = _coerce_tokamax_block_sizes(flash_block_sizes) |
237 | 265 | else: |
238 | 266 | block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size |
239 | 267 | block_sizes = splash_attention_kernel.BlockSizes( |
@@ -327,71 +355,52 @@ def wrap_flash_attention(query, key, value): |
327 | 355 | residual_checkpoint_name=residual_checkpoint_name |
328 | 356 | ) |
329 | 357 |
|
330 | | - if attention_kernel == "tokamax_ring": |
331 | | - # For tokamax_ring, use the kernel directly without vmap |
332 | | - # The ring attention kernel handles the ring topology internally |
333 | | - if not mask_padding_tokens: |
334 | | - segment_ids = None |
335 | | - attention_output = splash_kernel( |
336 | | - fwd_mask_info=None, |
337 | | - dkv_mask_info=None, |
338 | | - q=query, |
339 | | - k=key, |
340 | | - v=value, |
341 | | - segment_ids=segment_ids, |
342 | | - is_mqa=False, |
343 | | - config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name), |
344 | | - mask_value=-jnp.inf, |
345 | | - mask_function=None, |
346 | | - fwd_mask_sparsity=1.0, |
347 | | - save_residuals=True, |
348 | | - ) |
349 | | - else: |
350 | | - vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) |
351 | 358 |
|
352 | | - if not mask_padding_tokens: |
353 | | - segment_ids = None |
354 | | - if attention_kernel in ["flash", "tokamax_flash"]: |
355 | | - attention_output = vmapped_splash(query, key, value, segment_ids) |
356 | | - else: |
357 | | - if num_fsdp_shards > 1: |
358 | | - out, (lse,) = vmapped_splash(query, key, value, segment_ids) |
359 | | - m = lse.astype(jnp.float32) |
360 | | - l = jnp.exp(lse - m) |
361 | | - o = out.astype(jnp.float32) * l[..., None] |
| 359 | + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) |
362 | 360 |
|
363 | | - perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)] |
| 361 | + if not mask_padding_tokens: |
| 362 | + segment_ids = None |
| 363 | + if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]: |
| 364 | + attention_output = vmapped_splash(query, key, value, segment_ids) |
| 365 | + else: |
| 366 | + if num_fsdp_shards > 1: |
| 367 | + out, (lse,) = vmapped_splash(query, key, value, segment_ids) |
| 368 | + m = lse.astype(jnp.float32) |
| 369 | + l = jnp.exp(lse - m) |
| 370 | + o = out.astype(jnp.float32) * l[..., None] |
364 | 371 |
|
365 | | - k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm) |
366 | | - v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm) |
| 372 | + perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)] |
367 | 373 |
|
368 | | - def ring_scan_body(carry, _): |
369 | | - m, l, o, k_current, v_current = carry |
370 | | - k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm) |
371 | | - v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm) |
| 374 | + k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm) |
| 375 | + v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm) |
372 | 376 |
|
373 | | - out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids) |
| 377 | + def ring_scan_body(carry, _): |
| 378 | + m, l, o, k_current, v_current = carry |
| 379 | + k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm) |
| 380 | + v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm) |
374 | 381 |
|
375 | | - m_chunk = lse_chunk.astype(jnp.float32) |
376 | | - m_old = m |
377 | | - m = jnp.maximum(m_old, m_chunk) |
| 382 | + out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids) |
378 | 383 |
|
379 | | - exp_m_diff = jnp.exp(m_old - m) |
380 | | - exp_m_chunk_diff = jnp.exp(m_chunk - m) |
| 384 | + m_chunk = lse_chunk.astype(jnp.float32) |
| 385 | + m_old = m |
| 386 | + m = jnp.maximum(m_old, m_chunk) |
381 | 387 |
|
382 | | - l = l * exp_m_diff + jnp.exp(lse_chunk - m) |
383 | | - o = o * exp_m_diff[..., None] |
384 | | - o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32) |
| 388 | + exp_m_diff = jnp.exp(m_old - m) |
| 389 | + exp_m_chunk_diff = jnp.exp(m_chunk - m) |
385 | 390 |
|
386 | | - # Return the updated state for the next iteration |
387 | | - return (m, l, o, k_next, v_next), None |
| 391 | + l = l * exp_m_diff + jnp.exp(lse_chunk - m) |
| 392 | + o = o * exp_m_diff[..., None] |
| 393 | + o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32) |
388 | 394 |
|
389 | | - initial_carry = (m, l, o, k1, v1) |
390 | | - (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1) |
| 395 | + # Return the updated state for the next iteration |
| 396 | + return (m, l, o, k_next, v_next), None |
391 | 397 |
|
392 | | - attention_output = o_final / l_final[..., None] |
393 | | - else: |
394 | | - raise ValueError("ring attention requires fsdp > 1") |
| 398 | + initial_carry = (m, l, o, k1, v1) |
| 399 | + (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1) |
| 400 | + |
| 401 | + attention_output = o_final / l_final[..., None] |
| 402 | + else: |
| 403 | + raise ValueError("ring attention requires fsdp > 1") |
395 | 404 |
|
396 | 405 | return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) |
397 | 406 |
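The ring path above rotates K/V blocks around the `fsdp` axis (with `perm = [(j, (j + 1) % n)]`, every shard has seen every block after `n - 1` permute steps) and folds each chunk's result into a running log-sum-exp accumulator. Below is a self-contained sketch of that merge rule with hypothetical names; it mirrors the arithmetic of `ring_scan_body`, not the kernel's API:

```python
import jax.numpy as jnp


def merge_attention_chunk(m, l, o, out_chunk, lse_chunk):
  """Fold one KV chunk's (normalized output, log-sum-exp) into the running state.

  m: running max of the per-chunk log-sum-exps, shape [..., q_len]
  l: sum of exp(lse_i - m) over the chunks folded so far, same shape as m
  o: sum of exp(lse_i - m) * out_i over the chunks folded so far, [..., q_len, d]
  """
  m_chunk = lse_chunk.astype(jnp.float32)
  m_new = jnp.maximum(m, m_chunk)
  # Rescale the old accumulators and the new chunk to the shared maximum.
  l_new = l * jnp.exp(m - m_new) + jnp.exp(m_chunk - m_new)
  o_new = (
      o * jnp.exp(m - m_new)[..., None]
      + jnp.exp(m_chunk - m_new)[..., None] * out_chunk.astype(jnp.float32)
  )
  return m_new, l_new, o_new


# Each chunk's output is already normalized by exp(lse_i), so weighting chunks by
# exp(lse_i - m) and dividing by the summed weights recovers the global softmax:
#   attention = o / l[..., None]
# which is the `o_final / l_final[..., None]` step in the scan above.
```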
|
@@ -571,6 +580,7 @@ def _apply_attention( |
571 | 580 | return _tpu_flash_attention( |
572 | 581 | query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, |
573 | 582 | mask_padding_tokens=mask_padding_tokens, |
| 583 | + residual_checkpoint_name=residual_checkpoint_name, |
574 | 584 | ) |
575 | 585 | elif attention_kernel == "cudnn_flash_te": |
576 | 586 | return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) |
@@ -862,7 +872,8 @@ def __init__( |
862 | 872 | else: |
863 | 873 | axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV) |
864 | 874 | axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV) |
865 | | - |
| 875 | + if attention_kernel == "tokamax_ring" and not is_self_attention: |
| 876 | + attention_kernel = "tokamax_flash" # do not use ring attention for cross attention |
866 | 877 | self.attention_op = NNXAttentionOp( |
867 | 878 | mesh=mesh, |
868 | 879 | attention_kernel=attention_kernel, |
|