|
23 | 23 | import jax.numpy as jnp |
24 | 24 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask |
25 | 25 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel |
26 | | -from maxdiffusion.kernels.splash_attention import splash_attention_mask as tokamax_splash_attention_mask |
27 | 26 | from maxdiffusion.kernels.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel |
28 | | -from maxdiffusion.kernels.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel |
29 | | -from jax.experimental.shard_map import shard_map |
30 | | -from maxdiffusion.kernels.splash_attention import base as tokamax_splash_base |
31 | 27 | from einops import rearrange |
32 | 28 | from .. import common_types, max_logging |
33 | 29 |
|
@@ -315,179 +311,42 @@ def _tpu_flash_attention( |
315 | 311 | flash_block_sizes: BlockSizes, |
316 | 312 | dtype: jnp.dtype = jnp.float32, |
317 | 313 | attention_kernel: str = "flash", |
318 | | - mask_padding_tokens: bool = True, |
| 314 | + mask_padding_tokens: bool = False, |
319 | 315 | residual_checkpoint_name: str | None = None, |
320 | 316 | attention_mask: jax.Array = None, |
321 | 317 | use_base2_exp: bool = False, |
322 | 318 | use_experimental_scheduler: bool = False, |
323 | 319 | ) -> jax.Array: |
324 | | - """TPU Flash Attention""" |
325 | | - num_context_shards = mesh.shape["context"] |
326 | | - query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards) |
327 | | - key, _ = _reshape_data_for_flash(key, heads, num_context_shards) |
328 | | - value, _ = _reshape_data_for_flash(value, heads, num_context_shards) |
329 | | - block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel) |
330 | | - |
331 | | - q_axis_names = nn.logical_to_mesh_axes(axis_names_q) |
332 | | - kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) |
333 | | - |
334 | | - @functools.partial( |
335 | | - shard_map, |
| 320 | + """Torchax-Style Tensor Parallel TPU Flash Attention""" |
| 321 | + |
| 322 | + # 1. Reshape from (Batch, Seq_Len, Heads*Dim) -> (Batch, Heads, Seq_Len, Dim) |
| 323 | + # We pass num_context_shards=1 because we hold the full sequence locally. |
| 324 | + query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards=1) |
| 325 | + key, _ = _reshape_data_for_flash(key, heads, num_context_shards=1) |
| 326 | + value, _ = _reshape_data_for_flash(value, heads, num_context_shards=1) |
| 327 | + |
| 328 | + # 2. Call the built-in Torchax shard_map wrapper! |
| 329 | + # This automatically handles the SPMD boundaries, the min() padding logic for |
| 330 | + # cross-attention, and the axis swapping. |
| 331 | + x = custom_splash.tpu_custom_attention( |
| 332 | + query=query, |
| 333 | + key=key, |
| 334 | + value=value, |
336 | 335 | mesh=mesh, |
337 | | - in_specs=(q_axis_names, kv_axis_names, kv_axis_names), |
338 | | - out_specs=q_axis_names, |
339 | | - check_rep=False, |
| 336 | + # CRITICAL: JAX already scaled `key` by 1/sqrt(d) in the router. |
| 337 | + # Passing 1.0 ensures the kernel only applies the log2(e) multiplier to `query`. |
| 338 | + scale=1.0, |
| 339 | + block_q=2048, |
| 340 | + block_kv=2048, |
| 341 | + block_kv_compute=1024, |
| 342 | + block_kv_compute_in=256, |
| 343 | + heads_per_tile=1, |
340 | 344 | ) |
341 | | - def wrap_flash_attention(query, key, value): |
342 | | - uses_fused_kernel = block_sizes.use_fused_bwd_kernel |
343 | | - block_q_sizes = ( |
344 | | - block_sizes.block_q, |
345 | | - block_sizes.block_q_dkv, |
346 | | - ) |
347 | | - block_kv_sizes = ( |
348 | | - block_sizes.block_kv, |
349 | | - block_sizes.block_kv_dkv, |
350 | | - ) |
351 | | - if uses_fused_kernel: |
352 | | - block_q_sizes += (block_sizes.block_q_dkv,) |
353 | | - block_kv_sizes += (block_sizes.block_kv_dkv,) |
354 | | - else: |
355 | | - block_q_sizes += (block_sizes.block_q_dq,) |
356 | | - block_kv_sizes += (block_sizes.block_kv_dq,) |
357 | | - block_q = max(*block_q_sizes) |
358 | | - query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q) |
359 | | - |
360 | | - block_kv = max(*block_kv_sizes) |
361 | | - key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv) |
362 | | - value, _, _ = _pad_data_for_flash(value, heads, block_kv) |
363 | | - |
364 | | - mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) |
365 | | - multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) |
366 | | - |
367 | | - q_padded_len = query.shape[2] |
368 | | - q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0) |
369 | | - q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32) |
370 | | - |
371 | | - kv_padded_len = key.shape[2] |
372 | | - kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) |
373 | | - kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) |
374 | | - |
375 | | - if attention_mask is not None: |
376 | | - mask_len = min(key_seq_len, attention_mask.shape[1]) |
377 | | - kv_mask_for_batch = attention_mask[0, :mask_len] |
378 | | - if key_seq_len > mask_len: |
379 | | - extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32) |
380 | | - kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0) |
381 | | - if kv_padded_len > key_seq_len: |
382 | | - padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32) |
383 | | - kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) |
384 | | - else: |
385 | | - kv_mask_padded = kv_mask_for_batch |
386 | | - kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32) |
387 | | - |
388 | | - if attention_kernel == "tokamax_ring": |
389 | | - segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) |
390 | | - else: |
391 | | - segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) |
392 | | - |
393 | | - if attention_kernel == "tokamax_flash": |
394 | | - mask = tokamax_splash_attention_mask.FullMask( |
395 | | - _shape=(query.shape[2], key.shape[2]), |
396 | | - ) |
397 | | - splash_kernel = tokamax_splash_attention_kernel.make_splash_mha( |
398 | | - mask=mask, |
399 | | - q_seq_shards=1, |
400 | | - config=convert_to_tokamax_splash_config( |
401 | | - block_sizes, |
402 | | - residual_checkpoint_name=residual_checkpoint_name, |
403 | | - use_base2_exp=use_base2_exp, |
404 | | - use_experimental_scheduler=use_experimental_scheduler, |
405 | | - ), |
406 | | - save_residuals=False, |
407 | | - ) |
408 | | - elif attention_kernel == "tokamax_ring": |
409 | | - mask = tokamax_splash_attention_mask.FullMask( |
410 | | - _shape=(query.shape[2], key.shape[2]), |
411 | | - ) |
412 | | - splash_kernel = tokamax_ring_attention_kernel.make_ring_attention( |
413 | | - mask=mask, |
414 | | - is_mqa=False, |
415 | | - config=convert_to_tokamax_splash_config( |
416 | | - block_sizes, |
417 | | - residual_checkpoint_name=residual_checkpoint_name, |
418 | | - use_base2_exp=use_base2_exp, |
419 | | - use_experimental_scheduler=use_experimental_scheduler, |
420 | | - ), |
421 | | - save_residuals=False, |
422 | | - ring_axis="context", |
423 | | - rotate_segment_ids=False, |
424 | | - ) |
425 | | - else: |
426 | | - splash_kernel = splash_attention_kernel.make_splash_mha( |
427 | | - mask=multi_head_mask, |
428 | | - head_shards=1, # the sizes of the axis is sharding over heads |
429 | | - q_seq_shards=1, # the sizes of the axis is sharding over seq_len |
430 | | - block_sizes=block_sizes, |
431 | | - save_residuals=True if "ring" in attention_kernel else False, |
432 | | - residual_checkpoint_name=residual_checkpoint_name, |
433 | | - ) |
434 | | - |
435 | | - vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) |
436 | | - |
437 | | - if not mask_padding_tokens: |
438 | | - segment_ids = None |
439 | | - if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]: |
440 | | - attention_output = vmapped_splash(query, key, value, segment_ids) |
441 | | - else: |
442 | | - if num_context_shards > 1: |
443 | | - out, (lse,) = vmapped_splash(query, key, value, segment_ids) |
444 | | - m = lse.astype(jnp.float32) |
445 | | - l = jnp.exp(lse - m) |
446 | | - o = out.astype(jnp.float32) * l[..., None] |
447 | | - |
448 | | - perm = [(j, (j + 1) % num_context_shards) for j in range(num_context_shards)] |
449 | | - |
450 | | - k1 = jax.lax.ppermute(key, axis_name="context", perm=perm) |
451 | | - v1 = jax.lax.ppermute(value, axis_name="context", perm=perm) |
452 | | - |
453 | | - def ring_scan_body(carry, _): |
454 | | - m, l, o, k_current, v_current = carry |
455 | | - k_next = jax.lax.ppermute(k_current, axis_name="context", perm=perm) |
456 | | - v_next = jax.lax.ppermute(v_current, axis_name="context", perm=perm) |
457 | | - |
458 | | - out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids) |
459 | 345 |
|
460 | | - m_chunk = lse_chunk.astype(jnp.float32) |
461 | | - m_old = m |
462 | | - m = jnp.maximum(m_old, m_chunk) |
463 | | - |
464 | | - exp_m_diff = jnp.exp(m_old - m) |
465 | | - exp_m_chunk_diff = jnp.exp(m_chunk - m) |
466 | | - |
467 | | - l = l * exp_m_diff + jnp.exp(lse_chunk - m) |
468 | | - o = o * exp_m_diff[..., None] |
469 | | - o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32) |
470 | | - |
471 | | - return (m, l, o, k_next, v_next), None |
472 | | - |
473 | | - initial_carry = (m, l, o, k1, v1) |
474 | | - (m_final, l_final, o_final, _, _), _ = jax.lax.scan( |
475 | | - ring_scan_body, initial_carry, None, length=num_context_shards - 1 |
476 | | - ) |
477 | | - |
478 | | - attention_output = o_final / l_final[..., None] |
479 | | - else: |
480 | | - raise ValueError("ring attention requires context > 1") |
481 | | - return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) |
482 | | - |
483 | | - devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1) |
484 | | - if not (query.shape[0] / devices_in_batch_sharding).is_integer(): |
485 | | - max_logging.log( |
486 | | - "Warning, batch dimension should be shardable among the devices in data and fsdp" |
487 | | - f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" |
488 | | - ) |
489 | | - x = wrap_flash_attention(query, key, value) |
| 346 | + # 3. Trim back to original sequence length |
490 | 347 | x = x[:, :, :orig_q_seq_len, :] |
| 348 | + |
| 349 | + # 4. Flatten back to (Batch, Seq_Len, Heads * Head_Dim) |
491 | 350 | x = _reshape_heads_to_head_dim(x) |
492 | 351 |
|
493 | 352 | return x |
|
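Steps 1 and 4 in the new code move between the fused-head layout and the per-head layout the kernel expects. Below is a minimal sketch of that round-trip using `einops.rearrange` (already imported at the top of the file); the `_reshape_data_for_flash` / `_reshape_heads_to_head_dim` helpers are not shown in this diff, so the padding-free behavior here is an assumption, not their actual implementation.

```python
# Sketch only: assumes the reshape helpers are plain layout changes with no
# padding, which this diff does not show.
import jax.numpy as jnp
from einops import rearrange

batch, seq_len, heads, head_dim = 2, 4096, 24, 128
x = jnp.zeros((batch, seq_len, heads * head_dim))   # (Batch, Seq_Len, Heads*Dim)

# Step 1: (Batch, Seq_Len, Heads*Dim) -> (Batch, Heads, Seq_Len, Dim)
q = rearrange(x, "b s (h d) -> b h s d", h=heads)

# Step 4: (Batch, Heads, Seq_Len, Dim) -> (Batch, Seq_Len, Heads*Dim)
out = rearrange(q, "b h s d -> b s (h d)")
assert out.shape == x.shape
```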
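The `scale=1.0` comment in step 2 rests on two standard identities: the softmax logits are unchanged whether the 1/sqrt(d) factor is applied to the key upstream or to the scores inside the kernel, and a kernel that exponentiates with base 2 only needs to fold log2(e) into the (already scaled) query. A small numerical sketch of both, with hypothetical shapes and independent of the actual kernel:

```python
# Sketch: why a pre-scaled key means the kernel should use scale=1.0, and why
# the remaining multiplier is log2(e) for a kernel that uses exp2 internally.
import jax
import jax.numpy as jnp

d = 64
q = jax.random.normal(jax.random.PRNGKey(0), (128, d))
k = jax.random.normal(jax.random.PRNGKey(1), (128, d))

# Scaling the key before the matmul gives the same logits as scaling inside
# the kernel, so scaling again in the kernel would double-apply 1/sqrt(d).
logits_in_kernel = (q @ k.T) / jnp.sqrt(d)
logits_prescaled = q @ (k / jnp.sqrt(d)).T
assert jnp.allclose(logits_in_kernel, logits_prescaled, atol=1e-4)

# Base-2 exponent trick: exp(x) == 2**(x * log2(e)), so an exp2-based kernel
# only has to multiply the scaled logits by log2(e).
x = jnp.linspace(-4.0, 4.0, 9)
assert jnp.allclose(jnp.exp(x), jnp.exp2(x * jnp.log2(jnp.e)))
```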