Commit 181c0fa

rewrite logic
1 parent 1a0f833 commit 181c0fa

1 file changed: 184 additions & 35 deletions

src/maxdiffusion/models/attention_flax.py
@@ -21,14 +21,17 @@
 import jax
 from jax.ad_checkpoint import checkpoint_name
 import jax.numpy as jnp
+from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
+from maxdiffusion.kernels.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from maxdiffusion.kernels.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
+from maxdiffusion.kernels.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
+from maxdiffusion.kernels.splash_attention import base as tokamax_splash_base
 from einops import rearrange
 from .. import common_types, max_logging
 
 from . import custom_splash_attention as custom_splash
-
 from . import quantizations
 from .modeling_flax_utils import get_activation
 
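The newly imported shard_map is what the rewritten _tpu_flash_attention below is built on: the wrapped function runs once per device on its local shard of the "context" mesh axis, and the outputs are stitched back together along that axis. A minimal sketch of that pattern, assuming a hypothetical 1-D "context" mesh (the shapes and the doubling function are illustrative only, not code from this commit):

```python
# Illustrative sketch only -- not code from this commit.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import shard_map

# Hypothetical 1-D mesh named "context", matching the axis name used in the diff below.
mesh = Mesh(np.asarray(jax.devices()), ("context",))

def local_fn(x):
  # x is the per-device shard: (batch, heads, seq / num_context_shards, dim)
  return x * 2.0

sharded_fn = shard_map.shard_map(
    local_fn,
    mesh=mesh,
    in_specs=P(None, None, "context", None),   # shard the sequence axis
    out_specs=P(None, None, "context", None),  # stitch it back together
    check_rep=False,
)

x = jnp.ones((2, 4, 8 * jax.device_count(), 16))
assert sharded_fn(x).shape == x.shape
```
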
@@ -311,42 +314,191 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
-    mask_padding_tokens: bool = False,
+    mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
     use_base2_exp: bool = False,
     use_experimental_scheduler: bool = False,
 ) -> jax.Array:
-  """Torchax-Style Tensor Parallel TPU Flash Attention"""
-
-  # 1. Reshape from (Batch, Seq_Len, Heads*Dim) -> (Batch, Heads, Seq_Len, Dim)
-  # We pass num_context_shards=1 because we hold the full sequence locally.
-  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards=1)
-  key, _ = _reshape_data_for_flash(key, heads, num_context_shards=1)
-  value, _ = _reshape_data_for_flash(value, heads, num_context_shards=1)
-
-  # 2. Call the built-in Torchax shard_map wrapper!
-  # This automatically handles the SPMD boundaries, the min() padding logic for
-  # cross-attention, and the axis swapping.
-  x = custom_splash.tpu_custom_attention(
-      query=query,
-      key=key,
-      value=value,
+  """TPU Flash Attention"""
+
+  num_context_shards = mesh.shape["context"]
+  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
+  key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
+  value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
+
+  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+
+  @functools.partial(
+      shard_map.shard_map,
       mesh=mesh,
-      # CRITICAL: JAX already scaled `key` by 1/sqrt(d) in the router.
-      # Passing 1.0 ensures the kernel only applies the log2(e) multiplier to `query`.
-      scale=1.0,
-      block_q=2048,
-      block_kv=2048,
-      block_kv_compute=1024,
-      block_kv_compute_in=256,
-      heads_per_tile=1,
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
+      out_specs=q_axis_names,
+      check_rep=False,
   )
+  def wrap_flash_attention(query, key, value):
+    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+    block_q_sizes = (
+        block_sizes.block_q,
+        block_sizes.block_q_dkv,
+    )
+    block_kv_sizes = (
+        block_sizes.block_kv,
+        block_sizes.block_kv_dkv,
+    )
+    if uses_fused_kernel:
+      block_q_sizes += (block_sizes.block_q_dkv,)
+      block_kv_sizes += (block_sizes.block_kv_dkv,)
+    else:
+      block_q_sizes += (block_sizes.block_q_dq,)
+      block_kv_sizes += (block_sizes.block_kv_dq,)
 
-  # 3. Trim back to original sequence length
-  x = x[:, :, :orig_q_seq_len, :]
+    block_q = max(*block_q_sizes)
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
 
-  # 4. Flatten back to (Batch, Seq_Len, Heads * Head_Dim)
+    block_kv = max(*block_kv_sizes)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
+    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+    q_padded_len = query.shape[2]
+    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+    kv_padded_len = key.shape[2]
+    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+
+    # If attention_mask is provided, apply it to kv_segment_ids
+    if attention_mask is not None:
+      mask_len = min(key_seq_len, attention_mask.shape[1])
+      kv_mask_for_batch = attention_mask[0, :mask_len]  # (mask_len,)
+      # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid)
+      if key_seq_len > mask_len:
+        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
+        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)  # (key_seq_len,)
+      # Pad to kv_padded_len
+      if kv_padded_len > key_seq_len:
+        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
+        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)  # (kv_padded_len,)
+      else:
+        kv_mask_padded = kv_mask_for_batch
+      # Both are (kv_padded_len,) - element-wise multiplication
+      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+
+    if attention_kernel == "tokamax_ring":
+      segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+    else:
+      segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+
+    # make_splash_mha is wrapped around shardmap and seq and head is already
+    # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
+    if attention_kernel == "tokamax_flash":
+      mask = tokamax_splash_attention_mask.FullMask(
+          _shape=(query.shape[2], key.shape[2]),
+      )
+      splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
+          mask=mask,
+          q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
+          config=convert_to_tokamax_splash_config(
+              block_sizes,
+              residual_checkpoint_name=residual_checkpoint_name,
+              use_base2_exp=use_base2_exp,
+              use_experimental_scheduler=use_experimental_scheduler,
+          ),
+          save_residuals=False,
+      )
+    elif attention_kernel == "tokamax_ring":
+      mask = tokamax_splash_attention_mask.FullMask(
+          _shape=(query.shape[2], key.shape[2]),
+      )
+      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
+          mask=mask,
+          is_mqa=False,
+          config=convert_to_tokamax_splash_config(
+              block_sizes,
+              residual_checkpoint_name=residual_checkpoint_name,
+              use_base2_exp=use_base2_exp,
+              use_experimental_scheduler=use_experimental_scheduler,
+          ),
+          save_residuals=False,
+          ring_axis="context",
+          rotate_segment_ids=False,  # We don't rotate segment ids in tokamax ring attention because our segment ids is for padding each kv shard has same segment ids
+      )
+    else:
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,  # the sizes of the axis is sharding over heads
+          q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
+          block_sizes=block_sizes,
+          save_residuals=True if "ring" in attention_kernel else False,
+          residual_checkpoint_name=residual_checkpoint_name,
+      )
+
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+
+    if not mask_padding_tokens:
+      segment_ids = None
+    if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+    else:
+      if num_context_shards > 1:
+        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+        m = lse.astype(jnp.float32)
+        l = jnp.exp(lse - m)
+        o = out.astype(jnp.float32) * l[..., None]
+
+        perm = [(j, (j + 1) % num_context_shards) for j in range(num_context_shards)]
+
+        k1 = jax.lax.ppermute(key, axis_name="context", perm=perm)
+        v1 = jax.lax.ppermute(value, axis_name="context", perm=perm)
+
+        def ring_scan_body(carry, _):
+          m, l, o, k_current, v_current = carry
+          k_next = jax.lax.ppermute(k_current, axis_name="context", perm=perm)
+          v_next = jax.lax.ppermute(v_current, axis_name="context", perm=perm)
+
+          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
+
+          m_chunk = lse_chunk.astype(jnp.float32)
+          m_old = m
+          m = jnp.maximum(m_old, m_chunk)
+
+          exp_m_diff = jnp.exp(m_old - m)
+          exp_m_chunk_diff = jnp.exp(m_chunk - m)
+
+          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+          o = o * exp_m_diff[..., None]
+          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+
+          # Return the updated state for the next iteration
+          return (m, l, o, k_next, v_next), None
+
+        initial_carry = (m, l, o, k1, v1)
+        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(
+            ring_scan_body, initial_carry, None, length=num_context_shards - 1
+        )
+
+        attention_output = o_final / l_final[..., None]
+      else:
+        raise ValueError("ring attention requires context > 1")
+    return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+
+  devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
+  # This warning might show up when doing model eval for example, when calculating model flops
+  # and that is expected.
+  if not (query.shape[0] / devices_in_batch_sharding).is_integer():
+    max_logging.log(
+        "Warning, batch dimension should be shardable among the devices in data and fsdp"
+        f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
+    )
+  x = wrap_flash_attention(query, key, value)
+  # Trim back to original sequence length after context-axis padding.
+  x = x[:, :, :orig_q_seq_len, :]
   x = _reshape_heads_to_head_dim(x)
 
   return x
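
In the ring branch above, each scan step receives the next key/value shard via jax.lax.ppermute and folds its partial result into a running (m, l, o) state with a log-sum-exp rescale, so o_final / l_final equals attention over the full, unsharded sequence. A minimal sketch of that merge for two key/value chunks, with plain dot-product attention standing in for the vmapped splash kernel and no scaling or masking (illustrative only, not code from this commit):

```python
# Illustrative sketch only -- not code from this commit.
import jax
import jax.numpy as jnp

def chunk_attention(q, k, v):
  # Attention over one key/value chunk plus the log-sum-exp of its logits.
  logits = jnp.einsum("qd,kd->qk", q, k)
  lse = jax.nn.logsumexp(logits, axis=-1)     # (q_len,)
  out = jax.nn.softmax(logits, axis=-1) @ v   # (q_len, head_dim)
  return out, lse

def merge(out_a, lse_a, out_b, lse_b):
  # Combine two partial results as if their keys/values had been concatenated,
  # mirroring the m/l/o update in ring_scan_body.
  m = jnp.maximum(lse_a, lse_b)
  l = jnp.exp(lse_a - m) + jnp.exp(lse_b - m)
  o = jnp.exp(lse_a - m)[..., None] * out_a + jnp.exp(lse_b - m)[..., None] * out_b
  return o / l[..., None]

q = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (32, 16))
v = jax.random.normal(jax.random.PRNGKey(2), (32, 16))

out_a, lse_a = chunk_attention(q, k[:16], v[:16])   # local shard
out_b, lse_b = chunk_attention(q, k[16:], v[16:])   # shard received over the ring
reference, _ = chunk_attention(q, k, v)             # full-sequence attention
assert jnp.allclose(merge(out_a, lse_a, out_b, lse_b), reference, atol=1e-5)
```
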
@@ -429,6 +581,7 @@ def wrap_ulysses_attention(query, key, value):
     block_kv = max(*block_kv_sizes)
     key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
     value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
 
@@ -577,10 +730,6 @@ def wrap_ulysses_attention(query, key, value):
   return x
 
 
-def torchax_attention():
-  pass
-
-
 def _apply_attention_dot(
     query: Array,
     key: Array,
@@ -692,7 +841,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
-    mask_padding_tokens: bool = False,
+    mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
     attention_mask: Array = None,
     use_base2_exp: bool = False,
@@ -954,7 +1103,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
-      mask_padding_tokens: bool = False,
+      mask_padding_tokens: bool = True,
       residual_checkpoint_name: str | None = None,
       use_base2_exp: bool = False,
       use_experimental_scheduler: bool = False,
@@ -1122,7 +1271,7 @@ def __init__(
       qkv_bias: bool = False,
       quant: Quant = None,
       is_self_attention: bool = True,
-      mask_padding_tokens: bool = False,
+      mask_padding_tokens: bool = True,
       residual_checkpoint_name: str | None = None,
       enable_jax_named_scopes: bool = False,
       added_kv_proj_dim: Optional[int] = None,  # New for I2V
