Commit 3304769

Adding torchax custom attention

committed · 1 parent ddbce4a · commit 3304769

2 files changed: 684 additions & 725 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 125 additions & 187 deletions
@@ -21,12 +21,12 @@
 import jax
 from jax.ad_checkpoint import checkpoint_name
 import jax.numpy as jnp
-from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from maxdiffusion.kernels.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from maxdiffusion.kernels.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
 from maxdiffusion.kernels.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
+from jax.experimental.shard_map import shard_map
 from maxdiffusion.kernels.splash_attention import base as tokamax_splash_base
 from einops import rearrange
 from .. import common_types, max_logging
@@ -303,6 +303,7 @@ def convert_to_tokamax_splash_config(
       dq_reduction_steps=dq_reduction_steps,
   )
 
+
 def _tpu_flash_attention(
     query: jax.Array,
     key: jax.Array,
@@ -314,82 +315,88 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
-    mask_padding_tokens: bool = False,
+    mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
     use_base2_exp: bool = False,
     use_experimental_scheduler: bool = False,
 ) -> jax.Array:
-  """Torchax-Style Tensor Parallel TPU Flash Attention"""
-
-  # 1. Reshape from (Batch, Seq_Len, Heads*Dim) -> (Batch, Heads, Seq_Len, Dim)
-  # We pass num_context_shards=1 because we hold the full sequence locally.
-  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards=1)
-  key, _ = _reshape_data_for_flash(key, heads, num_context_shards=1)
-  value, _ = _reshape_data_for_flash(value, heads, num_context_shards=1)
-
-  # 2. Call the built-in Torchax shard_map wrapper!
-  # This automatically handles the SPMD boundaries, the min() padding logic for
-  # cross-attention, and the axis swapping.
-  x = custom_splash.tpu_custom_attention(
-      query=query,
-      key=key,
-      value=value,
-      mesh=mesh,
-      # CRITICAL: JAX already scaled `key` by 1/sqrt(d) in the router.
-      # Passing 1.0 ensures the kernel only applies the log2(e) multiplier to `query`.
-      scale=1.0,
-      block_q=2048,
-      block_kv=2048,
-      block_kv_compute=1024,
-      block_kv_compute_in=256,
-      heads_per_tile=1
+  """TPU Flash Attention"""
+  num_context_shards = mesh.shape["context"]
+  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
+  key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
+  value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
+
+  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
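+
+  # shard_map runs the body once per device over the mesh: q/k/v arrive
+  # pre-sharded according to in_specs, so the kernel sees only local shards.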
+  @functools.partial(
+      shard_map,
+      mesh=mesh,
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
+      out_specs=q_axis_names,
+      check_rep=False,
+  )
+  def wrap_flash_attention(query, key, value):
+    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+    block_q_sizes = (
+        block_sizes.block_q,
+        block_sizes.block_q_dkv,
     )
+    block_kv_sizes = (
+        block_sizes.block_kv,
+        block_sizes.block_kv_dkv,
+    )
+    if uses_fused_kernel:
+      block_q_sizes += (block_sizes.block_q_dkv,)
+      block_kv_sizes += (block_sizes.block_kv_dkv,)
+    else:
+      block_q_sizes += (block_sizes.block_q_dq,)
+      block_kv_sizes += (block_sizes.block_kv_dq,)
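+
+    # Pad the q and kv sequence lengths up to a multiple of the largest block
+    # size used by any pass so the kernel grid tiles the sequences evenly.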
+    block_q = max(*block_q_sizes)
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+
+    block_kv = max(*block_kv_sizes)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
+    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
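+
+    # Segment ids mark real tokens (1) vs padded positions (0); the splash
+    # kernel only lets q attend to kv entries whose segment id matches.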
+    q_padded_len = query.shape[2]
+    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+    kv_padded_len = key.shape[2]
+    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
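+
+    # Fold the caller-provided kv padding mask (shared across the batch, so
+    # row 0 is used) into the kv segment ids; masked tokens drop out as well.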
+    if attention_mask is not None:
+      mask_len = min(key_seq_len, attention_mask.shape[1])
+      kv_mask_for_batch = attention_mask[0, :mask_len]
+      if key_seq_len > mask_len:
+        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
+        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
+      if kv_padded_len > key_seq_len:
+        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
+        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
+      else:
+        kv_mask_padded = kv_mask_for_batch
+      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+
+    if attention_kernel == "tokamax_ring":
+      segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+    else:
+      segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
 
-  # 3. Trim back to original sequence length
-  x = x[:, :, :orig_q_seq_len, :]
-
-  # 4. Flatten back to (Batch, Seq_Len, Heads * Head_Dim)
-  x = _reshape_heads_to_head_dim(x)
-
-  return x
-
-
-# def _tpu_flash_attention(
-#     query: jax.Array,
-#     key: jax.Array,
-#     value: jax.Array,
-#     heads: int,
-#     mesh: Mesh,
-#     axis_names_q: AxisNames,
-#     axis_names_kv: AxisNames,
-#     flash_block_sizes: BlockSizes,
-#     dtype: jnp.dtype = jnp.float32,
-#     attention_kernel: str = "flash",
-#     mask_padding_tokens: bool = False,
-#     residual_checkpoint_name: str | None = None,
-#     attention_mask: jax.Array = None,
-# ) -> jax.Array:
-#   """TPU Flash Attention"""
-
-#   num_context_shards = mesh.shape["context"]
-#   query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
-#   key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
-#   value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
-#   block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
-
-#   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-#   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-
-    # make_splash_mha is wrapped around shardmap and seq and head is already
-    # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
     if attention_kernel == "tokamax_flash":
       mask = tokamax_splash_attention_mask.FullMask(
           _shape=(query.shape[2], key.shape[2]),
       )
       splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
           mask=mask,
-          q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
+          q_seq_shards=1,
           config=convert_to_tokamax_splash_config(
               block_sizes,
               residual_checkpoint_name=residual_checkpoint_name,
@@ -413,7 +420,7 @@ def _tpu_flash_attention(
           ),
           save_residuals=False,
           ring_axis="context",
-          rotate_segment_ids=False,  # We don't rotate segment ids in tokamax ring attention because our segment ids is for padding each kv shard has same segment ids
+          rotate_segment_ids=False,
       )
     else:
       splash_kernel = splash_attention_kernel.make_splash_mha(
@@ -425,134 +432,65 @@ def _tpu_flash_attention(
           residual_checkpoint_name=residual_checkpoint_name,
       )
 
-  # query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
-
-  # #block_kv = max(*block_kv_sizes)
-  # key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
-  # value, _, _ = _pad_data_for_flash(value, heads, bkv)
-
-  # bsizes = custom_splash._BlockSizes(
-  #     block_q=bq,
-  #     block_kv=bkv,
-  #     block_kv_compute=bkv_compute,
-  # )
-
-  # # mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-  # # multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-
-  # q_padded_len = query.shape[2]
-  # q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-  # q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
-
-  # kv_padded_len = key.shape[2]
-  # kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-  # kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
-
-  # # make_splash_mha is wrapped around shardmap and seq and head is already
-  # # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
-  # if attention_kernel == "tokamax_flash":
-  #   mask = tokamax_splash_attention_mask.FullMask(
-  #       _shape=(query.shape[2], key.shape[2]),
-  #   )
-  #   splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
-  #       mask=mask,
-  #       q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
-  #       config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-  #       save_residuals=False,
-  #   )
-  # elif attention_kernel == "tokamax_ring":
-  #   splash_kernel = custom_splash.make_splash_mha(
-  #       block_sizes=bsizes,
-  #       bkv_compute_in=bkv_compute_in,
-  #       orig_q_seq_len=query_seq_len,
-  #       orig_kv_seq_len=key_seq_len,
-  #       heads_per_tile=heads_per_tile
-  #   )
-  # else:
-  #   splash_kernel = custom_splash.make_splash_mha(
-  #       block_sizes=bsizes,
-  #       bkv_compute_in=bkv_compute_in,
-  #       orig_q_seq_len=query_seq_len,
-  #       orig_kv_seq_len=key_seq_len,
-  #       heads_per_tile=heads_per_tile
-  #   )
-
-  # vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0), out_axes=0)
-
-  # if not mask_padding_tokens:
-  #   segment_ids = None
-  # if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
-  #   query_scaled = query * 1.44269504
-  #   attention_output = vmapped_splash(query_scaled, key, value)
-  #   attention_output = jnp.swapaxes(attention_output, 2, 3)
-  # else:
-  #   if num_context_shards > 1:
-  #     out, (lse,) = vmapped_splash(query, key, value, segment_ids)
-  #     m = lse.astype(jnp.float32)
-  #     l = jnp.exp(lse - m)
-  #     o = out.astype(jnp.float32) * l[..., None]
-
-  #     perm = [(j, (j + 1) % num_context_shards) for j in range(num_context_shards)]
-
-  #     k1 = jax.lax.ppermute(key, axis_name="context", perm=perm)
-  #     v1 = jax.lax.ppermute(value, axis_name="context", perm=perm)
-
-  #     def ring_scan_body(carry, _):
-  #       m, l, o, k_current, v_current = carry
-  #       k_next = jax.lax.ppermute(k_current, axis_name="context", perm=perm)
-  #       v_next = jax.lax.ppermute(v_current, axis_name="context", perm=perm)
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+
+    if not mask_padding_tokens:
+      segment_ids = None
+    if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+    else:
+      if num_context_shards > 1:
+        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+        m = lse.astype(jnp.float32)
+        l = jnp.exp(lse - m)
+        o = out.astype(jnp.float32) * l[..., None]
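+        # (m, l, o) is the running online-softmax state: m is the max of the
+        # lses seen so far, l the rescaled normalizer, and o the unnormalized
+        # weighted sum of values; the final output is o / l.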
+
+        perm = [(j, (j + 1) % num_context_shards) for j in range(num_context_shards)]
+
+        k1 = jax.lax.ppermute(key, axis_name="context", perm=perm)
+        v1 = jax.lax.ppermute(value, axis_name="context", perm=perm)
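+
+        # Each scan step rotates the kv shards one hop around the "context"
+        # ring while folding the local partial attention into (m, l, o).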
+        def ring_scan_body(carry, _):
+          m, l, o, k_current, v_current = carry
+          k_next = jax.lax.ppermute(k_current, axis_name="context", perm=perm)
+          v_next = jax.lax.ppermute(v_current, axis_name="context", perm=perm)
+
+          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
+
+          m_chunk = lse_chunk.astype(jnp.float32)
+          m_old = m
+          m = jnp.maximum(m_old, m_chunk)
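+          # Rescale both accumulators to the new running max before adding
+          # this chunk's contribution (standard log-sum-exp merging).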
+          exp_m_diff = jnp.exp(m_old - m)
+          exp_m_chunk_diff = jnp.exp(m_chunk - m)
+
+          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+          o = o * exp_m_diff[..., None]
+          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+
+          return (m, l, o, k_next, v_next), None
+
+        initial_carry = (m, l, o, k1, v1)
+        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(
+            ring_scan_body, initial_carry, None, length=num_context_shards - 1
+        )
+
+        attention_output = o_final / l_final[..., None]
+      else:
+        raise ValueError("ring attention requires context > 1")
+    return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 
   devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
-  # This warning might show up when doing model eval for example, when calculating model flops
-  # and that is expected.
   if not (query.shape[0] / devices_in_batch_sharding).is_integer():
     max_logging.log(
         "Warning, batch dimension should be shardable among the devices in data and fsdp"
         f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
     )
   x = wrap_flash_attention(query, key, value)
-  # Trim back to original sequence length after context-axis padding.
   x = x[:, :, :orig_q_seq_len, :]
   x = _reshape_heads_to_head_dim(x)
 
-  #       m_chunk = lse_chunk.astype(jnp.float32)
-  #       m_old = m
-  #       m = jnp.maximum(m_old, m_chunk)
-
-  #       exp_m_diff = jnp.exp(m_old - m)
-  #       exp_m_chunk_diff = jnp.exp(m_chunk - m)
-
-  #       l = l * exp_m_diff + jnp.exp(lse_chunk - m)
-  #       o = o * exp_m_diff[..., None]
-  #       o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
-
-  #       # Return the updated state for the next iteration
-  #       return (m, l, o, k_next, v_next), None
-
-  #     initial_carry = (m, l, o, k1, v1)
-  #     (m_final, l_final, o_final, _, _), _ = jax.lax.scan(
-  #         ring_scan_body, initial_carry, None, length=num_context_shards - 1
-  #     )
-
-  #     attention_output = o_final / l_final[..., None]
-  #   else:
-  #     raise ValueError("ring attention requires context > 1")
-  #   return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
-
-  # devices_in_data_context = mesh.shape["data"] * mesh.shape["context"]
-  # # This warning might show up when doing model eval for example, when calculating model flops
-  # # and that is expected.
-  # if not (query.shape[0] / devices_in_data_context).is_integer():
-  #   max_logging.log(
-  #       "Warning, batch dimension should be shardable among the devices in data and context"
-  #       f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
-  #   )
-  # x = wrap_flash_attention(query, key, value)
-  # # Trim back to original sequence length after context-axis padding.
-  # x = x[:, :, :orig_q_seq_len, :]
-  # x = _reshape_heads_to_head_dim(x)
-
-  # return x
+  return x
 
 
 # ---------------------------------------------------------------------------
@@ -720,7 +658,7 @@ def _ulysses_custom_attention(
         "Ulysses attention requires the number of heads to be divisible by the context shard count, "
         f"got heads={num_heads} and context_shards={num_shards}."
     )
-
+
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
@@ -755,7 +693,7 @@ def wrap_ulysses_attention(query, key, value):
         bkv_compute_in=bkv_compute_in,
         orig_q_seq_len=query_seq_len,
        orig_kv_seq_len=key_seq_len,
-        heads_per_tile=heads_per_tile
+        heads_per_tile=heads_per_tile,
     )
 
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
@@ -779,11 +717,11 @@ def wrap_ulysses_attention(query, key, value):
 
   return x
 
-def torchax_attention(
-
-):
+
+def torchax_attention():
   pass
 
+
 def _apply_attention_dot(
     query: Array,
    key: Array,
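
The ring branch added above relies on the standard online-softmax merge: each splash call returns a locally normalized output together with its log-sum-exp (lse), and ring_scan_body rescales and accumulates those partials so that, after the final o_final / l_final division, the result equals attention over the full, unsharded kv sequence. Below is a minimal single-device sketch of that merge rule; it uses plain jnp, and attn_with_lse / merge_partials are illustrative helpers, not maxdiffusion or Pallas APIs.

import jax
import jax.numpy as jnp


def attn_with_lse(q, k, v):
  """Attention over one kv chunk; returns normalized output and its lse."""
  s = q @ k.T                                    # (q_len, kv_len) logits
  lse = jax.scipy.special.logsumexp(s, axis=-1)  # per-query log-normalizer
  return jnp.exp(s - lse[:, None]) @ v, lse


def merge_partials(o_a, lse_a, o_b, lse_b):
  """Log-sum-exp merge of two normalized partials, as in ring_scan_body."""
  m = jnp.maximum(lse_a, lse_b)                  # joint max for stability
  w_a, w_b = jnp.exp(lse_a - m), jnp.exp(lse_b - m)
  o = (w_a[:, None] * o_a + w_b[:, None] * o_b) / (w_a + w_b)[:, None]
  return o, m + jnp.log(w_a + w_b)


q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(q_key, (4, 8))
k = jax.random.normal(k_key, (6, 8))
v = jax.random.normal(v_key, (6, 8))

full, _ = attn_with_lse(q, k, v)             # one pass over the whole kv
o_a, lse_a = attn_with_lse(q, k[:3], v[:3])  # "shard" 0
o_b, lse_b = attn_with_lse(q, k[3:], v[3:])  # "shard" 1
merged, _ = merge_partials(o_a, lse_a, o_b, lse_b)
assert jnp.allclose(full, merged, atol=1e-5)  # chunked == unchunked

The committed kernel does the same bookkeeping incrementally: the chunk lse doubles as the running max, the normalizer stays unnormalized until the final division, and jax.lax.ppermute streams in the next shard's k/v while the current chunk is merged.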
