Commit 1a0f833

adding back comments
1 parent 3304769 commit 1a0f833

1 file changed

Lines changed: 27 additions & 168 deletions

src/maxdiffusion/models/attention_flax.py
@@ -23,11 +23,7 @@
 import jax.numpy as jnp
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
-from maxdiffusion.kernels.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from maxdiffusion.kernels.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
-from maxdiffusion.kernels.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
-from jax.experimental.shard_map import shard_map
-from maxdiffusion.kernels.splash_attention import base as tokamax_splash_base
 from einops import rearrange
 from .. import common_types, max_logging

@@ -315,179 +311,42 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
-    mask_padding_tokens: bool = True,
+    mask_padding_tokens: bool = False,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
     use_base2_exp: bool = False,
     use_experimental_scheduler: bool = False,
 ) -> jax.Array:
-  """TPU Flash Attention"""
-  num_context_shards = mesh.shape["context"]
-  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
-  key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
-  value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
-  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
-
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-
-  @functools.partial(
-      shard_map,
+  """Torchax-Style Tensor Parallel TPU Flash Attention"""
+
+  # 1. Reshape from (Batch, Seq_Len, Heads*Dim) -> (Batch, Heads, Seq_Len, Dim)
+  # We pass num_context_shards=1 because we hold the full sequence locally.
+  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards=1)
+  key, _ = _reshape_data_for_flash(key, heads, num_context_shards=1)
+  value, _ = _reshape_data_for_flash(value, heads, num_context_shards=1)
+
+  # 2. Call the built-in Torchax shard_map wrapper!
+  # This automatically handles the SPMD boundaries, the min() padding logic for
+  # cross-attention, and the axis swapping.
+  x = custom_splash.tpu_custom_attention(
+      query=query,
+      key=key,
+      value=value,
       mesh=mesh,
-      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
-      out_specs=q_axis_names,
-      check_rep=False,
+      # CRITICAL: JAX already scaled `key` by 1/sqrt(d) in the router.
+      # Passing 1.0 ensures the kernel only applies the log2(e) multiplier to `query`.
+      scale=1.0,
+      block_q=2048,
+      block_kv=2048,
+      block_kv_compute=1024,
+      block_kv_compute_in=256,
+      heads_per_tile=1,
   )
-  def wrap_flash_attention(query, key, value):
-    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
-    block_q_sizes = (
-        block_sizes.block_q,
-        block_sizes.block_q_dkv,
-    )
-    block_kv_sizes = (
-        block_sizes.block_kv,
-        block_sizes.block_kv_dkv,
-    )
-    if uses_fused_kernel:
-      block_q_sizes += (block_sizes.block_q_dkv,)
-      block_kv_sizes += (block_sizes.block_kv_dkv,)
-    else:
-      block_q_sizes += (block_sizes.block_q_dq,)
-      block_kv_sizes += (block_sizes.block_kv_dq,)
-    block_q = max(*block_q_sizes)
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-
-    block_kv = max(*block_kv_sizes)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
-
-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-
-    q_padded_len = query.shape[2]
-    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
-
-    kv_padded_len = key.shape[2]
-    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
-
-    if attention_mask is not None:
-      mask_len = min(key_seq_len, attention_mask.shape[1])
-      kv_mask_for_batch = attention_mask[0, :mask_len]
-      if key_seq_len > mask_len:
-        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
-        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
-      if kv_padded_len > key_seq_len:
-        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
-        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
-      else:
-        kv_mask_padded = kv_mask_for_batch
-      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
-
-    if attention_kernel == "tokamax_ring":
-      segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-    else:
-      segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-
-    if attention_kernel == "tokamax_flash":
-      mask = tokamax_splash_attention_mask.FullMask(
-          _shape=(query.shape[2], key.shape[2]),
-      )
-      splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
-          mask=mask,
-          q_seq_shards=1,
-          config=convert_to_tokamax_splash_config(
-              block_sizes,
-              residual_checkpoint_name=residual_checkpoint_name,
-              use_base2_exp=use_base2_exp,
-              use_experimental_scheduler=use_experimental_scheduler,
-          ),
-          save_residuals=False,
-      )
-    elif attention_kernel == "tokamax_ring":
-      mask = tokamax_splash_attention_mask.FullMask(
-          _shape=(query.shape[2], key.shape[2]),
-      )
-      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-          mask=mask,
-          is_mqa=False,
-          config=convert_to_tokamax_splash_config(
-              block_sizes,
-              residual_checkpoint_name=residual_checkpoint_name,
-              use_base2_exp=use_base2_exp,
-              use_experimental_scheduler=use_experimental_scheduler,
-          ),
-          save_residuals=False,
-          ring_axis="context",
-          rotate_segment_ids=False,
-      )
-    else:
-      splash_kernel = splash_attention_kernel.make_splash_mha(
-          mask=multi_head_mask,
-          head_shards=1,  # the sizes of the axis is sharding over heads
-          q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
-          block_sizes=block_sizes,
-          save_residuals=True if "ring" in attention_kernel else False,
-          residual_checkpoint_name=residual_checkpoint_name,
-      )
-
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
-
-    if not mask_padding_tokens:
-      segment_ids = None
-    if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
-      attention_output = vmapped_splash(query, key, value, segment_ids)
-    else:
-      if num_context_shards > 1:
-        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
-        m = lse.astype(jnp.float32)
-        l = jnp.exp(lse - m)
-        o = out.astype(jnp.float32) * l[..., None]
-
-        perm = [(j, (j + 1) % num_context_shards) for j in range(num_context_shards)]
-
-        k1 = jax.lax.ppermute(key, axis_name="context", perm=perm)
-        v1 = jax.lax.ppermute(value, axis_name="context", perm=perm)
-
-        def ring_scan_body(carry, _):
-          m, l, o, k_current, v_current = carry
-          k_next = jax.lax.ppermute(k_current, axis_name="context", perm=perm)
-          v_next = jax.lax.ppermute(v_current, axis_name="context", perm=perm)
-
-          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)

-          m_chunk = lse_chunk.astype(jnp.float32)
-          m_old = m
-          m = jnp.maximum(m_old, m_chunk)
-
-          exp_m_diff = jnp.exp(m_old - m)
-          exp_m_chunk_diff = jnp.exp(m_chunk - m)
-
-          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
-          o = o * exp_m_diff[..., None]
-          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
-
-          return (m, l, o, k_next, v_next), None
-
-        initial_carry = (m, l, o, k1, v1)
-        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(
-            ring_scan_body, initial_carry, None, length=num_context_shards - 1
-        )
-
-        attention_output = o_final / l_final[..., None]
-      else:
-        raise ValueError("ring attention requires context > 1")
-    return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
-
-  devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
-  if not (query.shape[0] / devices_in_batch_sharding).is_integer():
-    max_logging.log(
-        "Warning, batch dimension should be shardable among the devices in data and fsdp"
-        f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
-    )
-  x = wrap_flash_attention(query, key, value)
+
+  # 3. Trim back to original sequence length
   x = x[:, :, :orig_q_seq_len, :]
+
+  # 4. Flatten back to (Batch, Seq_Len, Heads * Head_Dim)
   x = _reshape_heads_to_head_dim(x)

   return x
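Editor's note on the `scale=1.0` comment in the added code: the commit asserts that `key` is already scaled by 1/sqrt(d) before the kernel is called, so the kernel should not scale again. The sketch below is not part of the commit; it is only a numerical check (with made-up shapes) that pre-scaling `key` and passing a kernel scale of 1.0 yields the same attention logits as an unscaled `key` with the usual 1/sqrt(d) applied inside the kernel.

# Illustrative check, not from this repository.
import jax
import jax.numpy as jnp

d = 64
kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(kq, (128, d))
k = jax.random.normal(kk, (128, d))

# Kernel applies the scale itself: logits = (q @ k.T) * 1/sqrt(d)
logits_in_kernel = (q @ k.T) * (1.0 / jnp.sqrt(d))
# Caller pre-scales `key` and the kernel uses scale=1.0, as in the added code.
logits_prescaled = (q @ (k * (1.0 / jnp.sqrt(d))).T) * 1.0

assert jnp.allclose(logits_in_kernel, logits_prescaled, atol=1e-5)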

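For readers tracing the 168 deleted lines: the removed `ring_scan_body` combined per-shard flash-attention outputs with a running (m, l, o) log-sum-exp merge before dividing by the accumulated normalizer. Below is a standalone sketch, with illustrative names and shapes that are not from this repository, showing that the same merge recovers exact full-sequence softmax attention from two KV chunks.

import jax
import jax.numpy as jnp

def chunk_attention(q, k, v):
  """Flash-style partial result for one KV chunk: normalized output plus its LSE."""
  s = q @ k.T                                      # (q_len, kv_len) logits
  lse = jax.scipy.special.logsumexp(s, axis=-1)    # (q_len,)
  return jnp.exp(s - lse[:, None]) @ v, lse

def init_carry(out, lse):
  """Mirror of the deleted scan's carry: running max m, normalizer l, unnormalized o."""
  m = lse
  l = jnp.exp(lse - m)            # all ones here, kept for symmetry with the scan
  o = out * l[:, None]
  return m, l, o

def merge_chunk(carry, out_chunk, lse_chunk):
  """The numerically stable merge the deleted ring scan performed once per step."""
  m, l, o = carry
  m_new = jnp.maximum(m, lse_chunk)
  scale_old = jnp.exp(m - m_new)
  scale_new = jnp.exp(lse_chunk - m_new)
  l = l * scale_old + scale_new
  o = o * scale_old[:, None] + scale_new[:, None] * out_chunk
  return m_new, l, o

key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (8, 32))    # (q_len, head_dim)
k = jax.random.normal(key_k, (16, 32))   # (kv_len, head_dim)
v = jax.random.normal(key_v, (16, 32))

reference, _ = chunk_attention(q, k, v)  # attention over the full KV at once

# Two-chunk version: accumulate the first half, then merge in the second half.
m, l, o = init_carry(*chunk_attention(q, k[:8], v[:8]))
m, l, o = merge_chunk((m, l, o), *chunk_attention(q, k[8:], v[8:]))
assert jnp.allclose(o / l[:, None], reference, atol=1e-5)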
0 commit comments
