|
21 | 21 | import jax |
22 | 22 | from jax.ad_checkpoint import checkpoint_name |
23 | 23 | import jax.numpy as jnp |
| 24 | +from jax.experimental import shard_map |
24 | 25 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask |
25 | 26 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel |
| 27 | +from maxdiffusion.kernels.splash_attention import splash_attention_mask as tokamax_splash_attention_mask |
26 | 28 | from maxdiffusion.kernels.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel |
| 29 | +from maxdiffusion.kernels.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel |
| 30 | +from maxdiffusion.kernels.splash_attention import base as tokamax_splash_base |
27 | 31 | from einops import rearrange |
28 | 32 | from .. import common_types, max_logging |
29 | 33 |
|
30 | 34 | from . import custom_splash_attention as custom_splash |
31 | | - |
32 | 35 | from . import quantizations |
33 | 36 | from .modeling_flax_utils import get_activation |
34 | 37 |
|
@@ -311,42 +314,191 @@ def _tpu_flash_attention( |
311 | 314 | flash_block_sizes: BlockSizes, |
312 | 315 | dtype: jnp.dtype = jnp.float32, |
313 | 316 | attention_kernel: str = "flash", |
314 | | - mask_padding_tokens: bool = False, |
| 317 | + mask_padding_tokens: bool = True, |
315 | 318 | residual_checkpoint_name: str | None = None, |
316 | 319 | attention_mask: jax.Array = None, |
317 | 320 | use_base2_exp: bool = False, |
318 | 321 | use_experimental_scheduler: bool = False, |
319 | 322 | ) -> jax.Array: |
320 | | - """Torchax-Style Tensor Parallel TPU Flash Attention""" |
321 | | - |
322 | | - # 1. Reshape from (Batch, Seq_Len, Heads*Dim) -> (Batch, Heads, Seq_Len, Dim) |
323 | | - # We pass num_context_shards=1 because we hold the full sequence locally. |
324 | | - query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards=1) |
325 | | - key, _ = _reshape_data_for_flash(key, heads, num_context_shards=1) |
326 | | - value, _ = _reshape_data_for_flash(value, heads, num_context_shards=1) |
327 | | - |
328 | | - # 2. Call the built-in Torchax shard_map wrapper! |
329 | | - # This automatically handles the SPMD boundaries, the min() padding logic for |
330 | | - # cross-attention, and the axis swapping. |
331 | | - x = custom_splash.tpu_custom_attention( |
332 | | - query=query, |
333 | | - key=key, |
334 | | - value=value, |
| 323 | + """TPU Flash Attention""" |
| 324 | + |
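| | + # Reshape (Batch, Seq, Heads*Dim) -> (Batch, Heads, Seq, Dim); the sequence is |
| | + # padded so it splits evenly across the context axis. |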
| 325 | + num_context_shards = mesh.shape["context"] |
| 326 | + query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards) |
| 327 | + key, _ = _reshape_data_for_flash(key, heads, num_context_shards) |
| 328 | + value, _ = _reshape_data_for_flash(value, heads, num_context_shards) |
| 329 | + block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel) |
| 330 | + |
| 331 | + q_axis_names = nn.logical_to_mesh_axes(axis_names_q) |
| 332 | + kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) |
| 333 | + |
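| | + # Run the kernel under shard_map: each device computes attention over its |
| | + # local q/kv shards as laid out by in_specs. |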
| 334 | + @functools.partial( |
| 335 | + shard_map.shard_map, |
335 | 336 | mesh=mesh, |
336 | | - # CRITICAL: JAX already scaled `key` by 1/sqrt(d) in the router. |
337 | | - # Passing 1.0 ensures the kernel only applies the log2(e) multiplier to `query`. |
338 | | - scale=1.0, |
339 | | - block_q=2048, |
340 | | - block_kv=2048, |
341 | | - block_kv_compute=1024, |
342 | | - block_kv_compute_in=256, |
343 | | - heads_per_tile=1, |
| 337 | + in_specs=(q_axis_names, kv_axis_names, kv_axis_names), |
| 338 | + out_specs=q_axis_names, |
| 339 | + check_rep=False, |
344 | 340 | ) |
| 341 | + def wrap_flash_attention(query, key, value): |
| 342 | + uses_fused_kernel = block_sizes.use_fused_bwd_kernel |
| 343 | + block_q_sizes = ( |
| 344 | + block_sizes.block_q, |
| 345 | + block_sizes.block_q_dkv, |
| 346 | + ) |
| 347 | + block_kv_sizes = ( |
| 348 | + block_sizes.block_kv, |
| 349 | + block_sizes.block_kv_dkv, |
| 350 | + ) |
| 351 | + if uses_fused_kernel: |
| 352 | + block_q_sizes += (block_sizes.block_q_dkv,) |
| 353 | + block_kv_sizes += (block_sizes.block_kv_dkv,) |
| 354 | + else: |
| 355 | + block_q_sizes += (block_sizes.block_q_dq,) |
| 356 | + block_kv_sizes += (block_sizes.block_kv_dq,) |
345 | 357 |
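| | + # Pad sequence lengths up to the largest block any kernel variant |
| | + # (fwd, dkv, dq) uses, so every grid tiles evenly. |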
|
346 | | - # 3. Trim back to original sequence length |
347 | | - x = x[:, :, :orig_q_seq_len, :] |
| 358 | + block_q = max(*block_q_sizes) |
| 359 | + query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q) |
348 | 360 |
|
349 | | - # 4. Flatten back to (Batch, Seq_Len, Heads * Head_Dim) |
| 361 | + block_kv = max(*block_kv_sizes) |
| 362 | + key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv) |
| 363 | + value, _, _ = _pad_data_for_flash(value, heads, block_kv) |
| 364 | + |
| 365 | + mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) |
| 366 | + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) |
| 367 | + |
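| | + # Build segment ids: 1 marks real tokens, 0 marks padding. The kernel only |
| | + # attends within a segment, so padded positions are masked out. |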
| 368 | + q_padded_len = query.shape[2] |
| 369 | + q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0) |
| 370 | + q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32) |
| 371 | + |
| 372 | + kv_padded_len = key.shape[2] |
| 373 | + kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) |
| 374 | + kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) |
| 375 | + |
| 376 | + # If attention_mask is provided, apply it to kv_segment_ids |
| 377 | + if attention_mask is not None: |
| 378 | + mask_len = min(key_seq_len, attention_mask.shape[1]) |
| 379 | + kv_mask_for_batch = attention_mask[0, :mask_len] # (mask_len,) |
| 380 | + # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid) |
| 381 | + if key_seq_len > mask_len: |
| 382 | + extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32) |
| 383 | + kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0) # (key_seq_len,) |
| 384 | + # Pad to kv_padded_len |
| 385 | + if kv_padded_len > key_seq_len: |
| 386 | + padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32) |
| 387 | + kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) # (kv_padded_len,) |
| 388 | + else: |
| 389 | + kv_mask_padded = kv_mask_for_batch |
| 390 | + # Both are (kv_padded_len,) - element-wise multiplication |
| 391 | + kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32) |
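| | + # Masked-out and padded positions now both carry segment id 0. |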
| 392 | + |
| 393 | + if attention_kernel == "tokamax_ring": |
| 394 | + segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) |
| 395 | + else: |
| 396 | + segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) |
| 397 | + |
| 398 | + # make_splash_mha runs inside shard_map, where the head and seq axes are |
| 399 | + # already sharded via in_specs, so head_shards=1 and q_seq_shards=1. |
| 400 | + if attention_kernel == "tokamax_flash": |
| 401 | + mask = tokamax_splash_attention_mask.FullMask( |
| 402 | + _shape=(query.shape[2], key.shape[2]), |
| 403 | + ) |
| 404 | + splash_kernel = tokamax_splash_attention_kernel.make_splash_mha( |
| 405 | + mask=mask, |
| 406 | + q_seq_shards=1, # seq_len is already sharded by shard_map |
| 407 | + config=convert_to_tokamax_splash_config( |
| 408 | + block_sizes, |
| 409 | + residual_checkpoint_name=residual_checkpoint_name, |
| 410 | + use_base2_exp=use_base2_exp, |
| 411 | + use_experimental_scheduler=use_experimental_scheduler, |
| 412 | + ), |
| 413 | + save_residuals=False, |
| 414 | + ) |
| 415 | + elif attention_kernel == "tokamax_ring": |
| 416 | + mask = tokamax_splash_attention_mask.FullMask( |
| 417 | + _shape=(query.shape[2], key.shape[2]), |
| 418 | + ) |
| 419 | + splash_kernel = tokamax_ring_attention_kernel.make_ring_attention( |
| 420 | + mask=mask, |
| 421 | + is_mqa=False, |
| 422 | + config=convert_to_tokamax_splash_config( |
| 423 | + block_sizes, |
| 424 | + residual_checkpoint_name=residual_checkpoint_name, |
| 425 | + use_base2_exp=use_base2_exp, |
| 426 | + use_experimental_scheduler=use_experimental_scheduler, |
| 427 | + ), |
| 428 | + save_residuals=False, |
| 429 | + ring_axis="context", |
| 430 | + rotate_segment_ids=False, # We don't rotate segment ids in tokamax ring attention: ours only mark padding, and every kv shard carries the same ids. |
| 431 | + ) |
| 432 | + else: |
| 433 | + splash_kernel = splash_attention_kernel.make_splash_mha( |
| 434 | + mask=multi_head_mask, |
| 435 | + head_shards=1, # heads are already sharded by shard_map |
| 436 | + q_seq_shards=1, # seq_len is already sharded by shard_map |
| 437 | + block_sizes=block_sizes, |
| 438 | + save_residuals="ring" in attention_kernel, # the lse residuals are needed for the ring combine |
| 439 | + residual_checkpoint_name=residual_checkpoint_name, |
| 440 | + ) |
| 441 | + |
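| | + # vmap over the batch dimension; the same segment ids apply to every batch element. |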
| 442 | + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) |
| 443 | + |
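| | + # Without segment ids the kernel also attends over padded tokens. |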
| 444 | + if not mask_padding_tokens: |
| 445 | + segment_ids = None |
| 446 | + if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]: |
| 447 | + attention_output = vmapped_splash(query, key, value, segment_ids) |
| 448 | + else: |
| 449 | + if num_context_shards > 1: |
| 450 | + out, (lse,) = vmapped_splash(query, key, value, segment_ids) |
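| | + # Online-softmax accumulators: m is the running max of the per-chunk lse, |
| | + # l the sum of exp(lse - m) weights, o the unnormalized output. |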
| 451 | + m = lse.astype(jnp.float32) |
| 452 | + l = jnp.exp(lse - m) |
| 453 | + o = out.astype(jnp.float32) * l[..., None] |
| 454 | + |
| 455 | + perm = [(j, (j + 1) % num_context_shards) for j in range(num_context_shards)] |
| 456 | + |
| 457 | + k1 = jax.lax.ppermute(key, axis_name="context", perm=perm) |
| 458 | + v1 = jax.lax.ppermute(value, axis_name="context", perm=perm) |
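| | + # k1/v1 hold the first rotated kv shard, fetched before the scan starts. |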
| 459 | + |
| 460 | + def ring_scan_body(carry, _): |
| 461 | + m, l, o, k_current, v_current = carry |
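| | + # Issue the rotation for the next step first so the collective can |
| | + # overlap with the attention compute on the current shard. |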
| 462 | + k_next = jax.lax.ppermute(k_current, axis_name="context", perm=perm) |
| 463 | + v_next = jax.lax.ppermute(v_current, axis_name="context", perm=perm) |
| 464 | + |
| 465 | + out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids) |
| 466 | + |
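| | + # Standard log-sum-exp merge: rescale both accumulators to the shared |
| | + # running max before summing. |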
| 467 | + m_chunk = lse_chunk.astype(jnp.float32) |
| 468 | + m_old = m |
| 469 | + m = jnp.maximum(m_old, m_chunk) |
| 470 | + |
| 471 | + exp_m_diff = jnp.exp(m_old - m) |
| 472 | + exp_m_chunk_diff = jnp.exp(m_chunk - m) |
| 473 | + |
| 474 | + l = l * exp_m_diff + jnp.exp(lse_chunk - m) |
| 475 | + o = o * exp_m_diff[..., None] |
| 476 | + o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32) |
| 477 | + |
| 478 | + # Return the updated state for the next iteration |
| 479 | + return (m, l, o, k_next, v_next), None |
| 480 | + |
| 481 | + initial_carry = (m, l, o, k1, v1) |
| 482 | + (m_final, l_final, o_final, _, _), _ = jax.lax.scan( |
| 483 | + ring_scan_body, initial_carry, None, length=num_context_shards - 1 |
| 484 | + ) |
| 485 | + |
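| | + # Normalize the accumulated output by the softmax denominator. |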
| 486 | + attention_output = o_final / l_final[..., None] |
| 487 | + else: |
| 488 | + raise ValueError("Ring attention requires the context mesh axis size to be > 1.") |
| 489 | + return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) |
| 490 | + |
| 491 | + devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1) |
| 492 | + # This warning may appear during model eval, e.g. when calculating model |
| 493 | + # FLOPs, and is expected there. |
| 494 | + if query.shape[0] % devices_in_batch_sharding != 0: |
| 495 | + max_logging.log( |
| 496 | + "Warning: batch dimension should be shardable across the devices on the data and fsdp" |
| 497 | + f" axes; batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" |
| 498 | + ) |
| 499 | + x = wrap_flash_attention(query, key, value) |
| 500 | + # Trim back to original sequence length after context-axis padding. |
| 501 | + x = x[:, :, :orig_q_seq_len, :] |
350 | 502 | x = _reshape_heads_to_head_dim(x) |
351 | 503 |
|
352 | 504 | return x |
@@ -429,6 +581,7 @@ def wrap_ulysses_attention(query, key, value): |
429 | 581 | block_kv = max(*block_kv_sizes) |
430 | 582 | key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv) |
431 | 583 | value, _, _ = _pad_data_for_flash(value, heads, block_kv) |
| 584 | + |
432 | 585 | mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) |
433 | 586 | multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) |
434 | 587 |
|
@@ -577,10 +730,6 @@ def wrap_ulysses_attention(query, key, value): |
577 | 730 | return x |
578 | 731 |
|
579 | 732 |
|
580 | | -def torchax_attention(): |
581 | | - pass |
582 | | - |
583 | | - |
584 | 733 | def _apply_attention_dot( |
585 | 734 | query: Array, |
586 | 735 | key: Array, |
@@ -692,7 +841,7 @@ def _apply_attention( |
692 | 841 | axis_names_kv: AxisNames, |
693 | 842 | flash_block_sizes: BlockSizes, |
694 | 843 | dpa_layer: Callable, |
695 | | - mask_padding_tokens: bool = False, |
| 844 | + mask_padding_tokens: bool = True, |
696 | 845 | residual_checkpoint_name: str | None = None, |
697 | 846 | attention_mask: Array = None, |
698 | 847 | use_base2_exp: bool = False, |
@@ -954,7 +1103,7 @@ def __init__( |
954 | 1103 | flash_block_sizes: BlockSizes = None, |
955 | 1104 | dtype: DType = jnp.float32, |
956 | 1105 | quant: Quant = None, |
957 | | - mask_padding_tokens: bool = False, |
| 1106 | + mask_padding_tokens: bool = True, |
958 | 1107 | residual_checkpoint_name: str | None = None, |
959 | 1108 | use_base2_exp: bool = False, |
960 | 1109 | use_experimental_scheduler: bool = False, |
@@ -1122,7 +1271,7 @@ def __init__( |
1122 | 1271 | qkv_bias: bool = False, |
1123 | 1272 | quant: Quant = None, |
1124 | 1273 | is_self_attention: bool = True, |
1125 | | - mask_padding_tokens: bool = False, |
| 1274 | + mask_padding_tokens: bool = True, |
1126 | 1275 | residual_checkpoint_name: str | None = None, |
1127 | 1276 | enable_jax_named_scopes: bool = False, |
1128 | 1277 | added_kv_proj_dim: Optional[int] = None, # New for I2V |
|