diff --git a/.gitignore b/.gitignore index d0f0ea7bd..9b0e1abfa 100644 --- a/.gitignore +++ b/.gitignore @@ -181,3 +181,6 @@ wandb # Gemini CLI .gemini/ gha-creds-*.json + +# JAX cache +.jax_cache/ diff --git a/README.md b/README.md index ac97f46c1..3f7ff4b1d 100755 --- a/README.md +++ b/README.md @@ -572,6 +572,26 @@ To generate images, run the following command: * For Wan2.2 T2V, use `base_wan_27b.yml`. * For Wan2.2 I2V, use `base_wan_i2v_27b.yml`. + ### Ulysses Attention + + MaxDiffusion supports Ulysses attention for WAN TPU inference. Enable it by setting `attention="ulysses"`. + + Internally, this follows the Ulysses sequence-parallel attention pattern and trades sequence shards for head shards around the local TPU splash kernel. For background, see [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://arxiv.org/abs/2309.14509). + + Set the override in your config YAML or pass it on the command line: + + ```bash + python src/maxdiffusion/generate_wan.py \ + src/maxdiffusion/configs/base_wan_i2v_27b.yml \ + attention="ulysses" \ + ici_context_parallelism=4 \ + ... + ``` + + Ulysses requires `ici_context_parallelism` greater than 1, and the number of attention heads must be divisible by the context shard count. `flash_block_sizes` remains optional and can still be set for hardware-specific tuning. + + In our Wan2.2 I2V benchmarks at 40 inference steps, 81 frames, and `720x1280` resolution, Ulysses improved inference time by roughly `10%` compared with flash attention, about `20s` lower latency on v6e-8 and v7x-8 TPU setups. + ### Caching Mechanisms Wan 2.x pipelines support several caching strategies to accelerate inference by skipping redundant transformer forward passes. These are **mutually exclusive** — enable only one at a time. @@ -774,4 +794,4 @@ This script will automatically format your code with `pyink` and help you identi The full suite of -end-to end tests is in `tests` and `src/maxdiffusion/tests`. We run them with a nightly cadance. ## Profiling -To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md). \ No newline at end of file +To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md).
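The README addition above describes the core Ulysses move in one sentence: trade sequence shards for head shards around a local attention kernel. As a rough, standalone illustration of that data movement, here is a minimal JAX sketch. It is not part of the patch; the mesh and shapes are toy values, and plain softmax attention stands in for the TPU splash kernel used by MaxDiffusion, but the two `all_to_all` calls mirror the ones added to `attention_flax.py` further down.

```python
# Minimal sketch of the Ulysses sequence-parallel exchange (illustrative only,
# not MaxDiffusion code). Each device starts with a sequence shard of all heads,
# trades it via all-to-all for the full sequence on a subset of heads, attends
# locally, then trades back. Plain softmax attention stands in for the TPU
# splash kernel; mesh and shapes are toy values.
import functools

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices()[:2], ("context",))
batch, heads, seq, head_dim = 1, 4, 8, 16  # heads must divide evenly across context shards


@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P(None, None, "context", None),) * 3,  # q, k, v arrive sequence-sharded
    out_specs=P(None, None, "context", None),
    check_vma=False,
)
def ulysses_attention(q, k, v):
  # Sequence shards -> head shards: each device now holds the whole sequence
  # for heads / num_shards heads.
  swap = lambda x: jax.lax.all_to_all(x, "context", split_axis=1, concat_axis=2, tiled=True)
  q, k, v = swap(q), swap(k), swap(v)
  # Any full-sequence attention can run here on the fewer-heads tensors.
  logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  out = jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(logits, axis=-1), v)
  # Head shards -> sequence shards, restoring the caller's layout.
  return jax.lax.all_to_all(out, "context", split_axis=2, concat_axis=1, tiled=True)


q = k = v = jnp.ones((batch, heads, seq, head_dim), dtype=jnp.float32)
print(ulysses_attention(q, k, v).shape)  # (1, 4, 8, 16)
```

The head-divisibility requirement called out in the README comes directly from this exchange: the all-to-all splits the head axis evenly across the context shards, so the head count must be a multiple of the shard count.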
diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index 410ad8d28..feae6e933 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -84,3 +84,13 @@ [CROSS_ATTN_Q_LENGTH, CONTEXT], [CROSS_ATTN_KV_LENGTH, None], ] + +### Common axis rules for ulysses attention ### +ULYSSES_ATTENTION_AXIS_RULES = [ + [SELF_ATTN_HEAD, None], + [SELF_ATTN_Q_LENGTH, CONTEXT], + [SELF_ATTN_KV_LENGTH, CONTEXT], + [CROSS_ATTN_HEAD, None], + [CROSS_ATTN_Q_LENGTH, CONTEXT], + [CROSS_ATTN_KV_LENGTH, CONTEXT], +] diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index fa036815b..e6b900657 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -60,7 +60,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses flash_min_seq_length: 0 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index a70abe95d..6e97d081e 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -60,7 +60,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses flash_min_seq_length: 0 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index caa5ab322..28a8c1cdf 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -60,7 +60,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses flash_min_seq_length: 4096 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. 
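The `ULYSSES_ATTENTION_AXIS_RULES` added to `common_types.py` above leave attention heads unsharded at the logical-layout level and pin both query and key/value sequence lengths to the `context` mesh axis. As a rough illustration of how such rules resolve (the logical axis names below are placeholders, not the exact string constants defined in `common_types.py`), flax turns them into `PartitionSpec`s like this:

```python
# Rough illustration of logical-to-mesh axis resolution with rules shaped like
# ULYSSES_ATTENTION_AXIS_RULES. The logical axis names here are placeholders;
# the real constants (SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, ...) live in
# maxdiffusion/common_types.py.
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

ulysses_like_rules = (
    ("self_attn_heads", None),           # heads stay unsharded in the layout
    ("self_attn_q_length", "context"),   # query sequence sharded over "context"
    ("self_attn_kv_length", "context"),  # key/value sequence sharded over "context"
)

with nn_partitioning.axis_rules(ulysses_like_rules):
  # Dims with no matching rule (batch, head_dim) default to unsharded (None).
  spec = nn.logical_to_mesh_axes(("batch", "self_attn_heads", "self_attn_q_length", "head_dim"))

print(spec)  # PartitionSpec(None, None, 'context', None)
```

This is the same mechanism `_ulysses_attention` relies on when it calls `nn.logical_to_mesh_axes` to build the `shard_map` in/out specs in the `attention_flax.py` change below.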
diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index fcbecfb66..f58ad8558 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -60,7 +60,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index f5b8191c2..bc4b32f0d 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -60,7 +60,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 9783d6480..583695fa7 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -443,6 +443,144 @@ def ring_scan_body(carry, _): return x +# --------------------------------------------------------------------------- +# Ulysses sequence-parallel attention +# --------------------------------------------------------------------------- + + +def _ulysses_attention( + query: jax.Array, + key: jax.Array, + value: jax.Array, + heads: int, + mesh: Mesh, + axis_names_q: AxisNames, + axis_names_kv: AxisNames, + flash_block_sizes: BlockSizes, + dtype: jnp.dtype = jnp.float32, + mask_padding_tokens: bool = True, + residual_checkpoint_name: str | None = None, + attention_mask: jax.Array = None, +) -> jax.Array: + """Ulysses sequence-parallel attention. + + Tensors arrive sequence-sharded on the context axis. Inside a shard_map the + all-to-all collectives trade sequence shards for head shards, run local + splash attention on the full sequence with a subset of heads, then all-to-all + back. + """ + axis_name = "context" + num_shards = mesh.shape[axis_name] + + # Reshape to [b, h, s, d] and pad sequence for even context-axis splitting. + query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_shards) + key, _ = _reshape_data_for_flash(key, heads, num_shards) + value, _ = _reshape_data_for_flash(value, heads, num_shards) + num_heads = query.shape[1] + # Ulysses only redistributes existing heads across the context mesh; unlike + # the earlier draft, we fail fast instead of padding synthetic heads. + if num_heads % num_shards != 0: + raise ValueError( + "Ulysses attention requires the number of heads to be divisible by the context shard count, " + f"got heads={num_heads} and context_shards={num_shards}." 
+ ) + block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash") + + q_axis_names = nn.logical_to_mesh_axes(axis_names_q) + kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) + + @functools.partial( + jax.shard_map, + mesh=mesh, + in_specs=(q_axis_names, kv_axis_names, kv_axis_names), + out_specs=q_axis_names, + check_vma=False, + ) + def wrap_ulysses_attention(query, key, value): + # Swap sharding modes: each device gives up a slice of sequence and gathers + # a slice of heads, so the local splash kernel sees the full sequence. + query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) + key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) + value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) + + # Run the same local splash kernel as standard TPU flash attention, but now + # on full-sequence / fewer-heads tensors produced by the all-to-all above. + uses_fused_kernel = block_sizes.use_fused_bwd_kernel + block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv) + block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv) + if uses_fused_kernel: + block_q_sizes += (block_sizes.block_q_dkv,) + block_kv_sizes += (block_sizes.block_kv_dkv,) + else: + block_q_sizes += (block_sizes.block_q_dq,) + block_kv_sizes += (block_sizes.block_kv_dq,) + + block_q = max(*block_q_sizes) + query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q) + block_kv = max(*block_kv_sizes) + key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv) + value, _, _ = _pad_data_for_flash(value, heads, block_kv) + + mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) + + q_padded_len = query.shape[2] + q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0) + q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32) + + kv_padded_len = key.shape[2] + kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) + kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) + + # Reuse the standard flash-attention masking convention by zeroing invalid + # KV positions in the segment ids passed down to splash. 
+ if attention_mask is not None: + mask_len = min(key_seq_len, attention_mask.shape[1]) + kv_mask_for_batch = attention_mask[0, :mask_len] + if key_seq_len > mask_len: + extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32) + kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0) + if kv_padded_len > key_seq_len: + padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32) + kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) + else: + kv_mask_padded = kv_mask_for_batch + kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32) + + segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) + if not mask_padding_tokens: + segment_ids = None + + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, + head_shards=1, + q_seq_shards=1, + block_sizes=block_sizes, + save_residuals=False, + residual_checkpoint_name=residual_checkpoint_name, + ) + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) + attention_output = vmapped_splash(query, key, value, segment_ids) + attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) + + # Restore the original layout expected by the rest of the model: + # head-sharded / full-sequence -> sequence-sharded / full-heads. + attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True) + return attention_output + + devices_in_data_context = mesh.shape["data"] * num_shards + if not (query.shape[0] / devices_in_data_context).is_integer(): + max_logging.log( + "Warning, batch dimension should be shardable among the devices in data and context" + f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}" + ) + x = wrap_ulysses_attention(query, key, value) + x = x[:, :, :orig_q_seq_len, :] + x = _reshape_heads_to_head_dim(x) + + return x + + def _apply_attention_dot( query: Array, key: Array, @@ -563,7 +701,7 @@ def _apply_attention( seq_len_idx = 1 if query.ndim == 4: seq_len_idx = 2 - if attention_kernel in ["flash", "tokamax_flash"]: + if attention_kernel in ["flash", "tokamax_flash", "ulysses"]: can_use_flash_attention = ( query.shape[seq_len_idx] >= flash_min_seq_length and key.shape[seq_len_idx] >= flash_min_seq_length @@ -575,6 +713,21 @@ def _apply_attention( return _apply_attention_dot( query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention ) + elif attention_kernel == "ulysses": + return _ulysses_attention( + query, + key * scale, + value, + heads, + mesh, + axis_names_q, + axis_names_kv, + flash_block_sizes, + dtype, + mask_padding_tokens=mask_padding_tokens, + residual_checkpoint_name=residual_checkpoint_name, + attention_mask=attention_mask, + ) elif attention_kernel in ["flash", "tokamax_flash"]: return _tpu_flash_attention( query, diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index ebcdeaea3..6d2bb90c7 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -27,7 +27,17 @@ from . import max_logging from . 
import max_utils from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH -from maxdiffusion.common_types import LENGTH, KV_LENGTH, WAN2_1, WAN2_2, LTX2_VIDEO, RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES +from maxdiffusion.common_types import ( + CONTEXT, + LENGTH, + KV_LENGTH, + WAN2_1, + WAN2_2, + LTX2_VIDEO, + RING_ATTENTION_AXIS_RULES, + SEQUENCE_PARALLEL_AXIS_RULES, + ULYSSES_ATTENTION_AXIS_RULES, +) _ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2, LTX2_VIDEO} _ALLOWED_TRAINING_MODEL_NAMES = {WAN2_1} @@ -200,25 +210,37 @@ def user_init(raw_keys): raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) # Verify qkv is sharded across sequence. - if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]: + attention = raw_keys["attention"] + uses_ring_attention = attention == "ring" + uses_ulysses_attention = attention == "ulysses" + uses_uniform_sequence_sharding = raw_keys["attention_sharding_uniform"] + if uses_ring_attention or uses_ulysses_attention or uses_uniform_sequence_sharding: max_logging.log( - f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set." + "Adding sequence sharding to q and kv if not already present because " + f"{attention=} requires it or attention_sharding_uniform={uses_uniform_sequence_sharding} is set." ) logical_axis_rules = list(raw_keys["logical_axis_rules"]) max_logging.log(f"Initial logical axis rules: {logical_axis_rules}") new_rules = [] - q_seq_sharding = (LENGTH, "context") - kv_seq_sharding = (KV_LENGTH, "context") + q_seq_sharding = (LENGTH, CONTEXT) + kv_seq_sharding = (KV_LENGTH, CONTEXT) if q_seq_sharding not in logical_axis_rules: logical_axis_rules.append(q_seq_sharding) + max_logging.log(f"Adding sequence length axis rule {q_seq_sharding}") if kv_seq_sharding not in logical_axis_rules: logical_axis_rules.append(kv_seq_sharding) - if raw_keys["attention"] == "ring": + max_logging.log(f"Adding key/value sequence axis rule {kv_seq_sharding}") + if uses_ring_attention: for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES: if ring_attention_axis_rule not in logical_axis_rules: max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}") new_rules.append(ring_attention_axis_rule) - else: # attention =flash but sequence parallel sharding requested for both self and cross attention + elif uses_ulysses_attention: + for ulysses_attention_axis_rule in ULYSSES_ATTENTION_AXIS_RULES: + if ulysses_attention_axis_rule not in logical_axis_rules: + max_logging.log(f"Adding ulysses attention axis rule {ulysses_attention_axis_rule}") + new_rules.append(ulysses_attention_axis_rule) + else: # attention=flash but sequence parallel sharding requested for both self and cross attention for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES: if seq_parallel_axis_rule not in logical_axis_rules: max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}") diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index 910b479a6..5c95dff8b 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -16,11 +16,16 @@ import os import unittest +from unittest import mock + from absl.testing import absltest +from flax.linen import partitioning as nn_partitioning import jax from jax.sharding import Mesh import jax.numpy as jnp from 
jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel +import numpy as np +from ..models import attention_flax from ..models.attention_flax import FlaxAttention, _select_flash_block_sizes from .. import max_utils from .. import pyconfig @@ -34,6 +39,41 @@ class AttentionTest(unittest.TestCase): def setUp(self): AttentionTest.dummy_data = {} + def _ulysses_mesh(self): + devices = np.array(jax.devices()[:2]).reshape(1, 1, 2, 1) + return Mesh(devices, ("data", "fsdp", "context", "tensor")) + + def _ulysses_axis_rules(self): + return ( + (attention_flax.BATCH, "data"), + (attention_flax.SELF_ATTN_HEAD, None), + (attention_flax.SELF_ATTN_Q_LENGTH, "context"), + (attention_flax.SELF_ATTN_KV_LENGTH, "context"), + (attention_flax.D_KV, None), + ) + + def _flash_axis_rules(self): + return ( + (attention_flax.BATCH, "data"), + (attention_flax.SELF_ATTN_HEAD, None), + (attention_flax.SELF_ATTN_Q_LENGTH, "context"), + (attention_flax.SELF_ATTN_KV_LENGTH, None), + (attention_flax.D_KV, None), + ) + + def _ulysses_block_sizes(self, block_size=4): + return attention_flax.BlockSizes( + block_q=block_size, + block_kv_compute=block_size, + block_kv=block_size, + block_q_dkv=block_size, + block_kv_dkv=block_size, + block_kv_dkv_compute=block_size, + block_q_dq=block_size, + block_kv_dq=block_size, + use_fused_bwd_kernel=False, + ) + def test_splash_attention(self): """Test numerics of splash attention are equivalent to dot_product""" @@ -146,6 +186,261 @@ def test_default_flash_block_sizes_use_sequence_axis_for_3d_inputs(self): assert block_sizes.block_q_dq == 1024 assert block_sizes.block_kv_dq == 128 + def test_select_flash_block_sizes_returns_configured_for_self_attention(self): + """Block-size selection should return the configured sizes unchanged for self-attention.""" + custom_block_sizes = self._ulysses_block_sizes(block_size=16) + query = jnp.zeros((1, 128, 1), dtype=jnp.float32) + key = jnp.zeros((1, 128, 1), dtype=jnp.float32) + + self_attention_block_sizes = _select_flash_block_sizes( + query=query, + key=key, + flash_block_sizes=custom_block_sizes, + dtype=jnp.float32, + attention_kernel="flash", + ) + self.assertIs(self_attention_block_sizes, custom_block_sizes) + + def test_select_flash_block_sizes_derives_cross_attn_defaults_for_tokamax(self): + """Block-size selection should derive cross-attn defaults and set tokamax_flash flags.""" + custom_block_sizes = self._ulysses_block_sizes(block_size=16) + query = jnp.zeros((1, 257, 1), dtype=jnp.float32) + key = jnp.zeros((1, 513, 1), dtype=jnp.float32) + + cross_attention_block_sizes = _select_flash_block_sizes( + query=query, + key=key, + flash_block_sizes=custom_block_sizes, + dtype=jnp.float32, + attention_kernel="tokamax_flash", + ) + self.assertEqual(cross_attention_block_sizes.block_q, 16) + self.assertEqual(cross_attention_block_sizes.block_kv, 513) + self.assertEqual(cross_attention_block_sizes.block_kv_compute, 513) + self.assertEqual(cross_attention_block_sizes.block_kv_dkv_compute, 257) + self.assertIsNone(cross_attention_block_sizes.block_q_dq) + self.assertIsNone(cross_attention_block_sizes.block_kv_dq) + self.assertTrue(cross_attention_block_sizes.use_fused_bwd_kernel) + + def test_ulysses_attention_round_trips_query_when_heads_are_divisible(self): + """Ulysses attention should preserve the query layout after its collectives.""" + batch = 2 + length = 5 + heads = 4 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = 
query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_mesh() + + def fake_make_splash_mha(**unused_kwargs): + def fake_kernel(q, k, v, segment_ids): + del k, v, segment_ids + return q + + return fake_kernel + + with ( + mesh, + nn_partitioning.axis_rules(self._ulysses_axis_rules()), + mock.patch.object( + attention_flax.splash_attention_kernel, + "make_splash_mha", + side_effect=fake_make_splash_mha, + ), + ): + output = attention_flax._ulysses_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ) + + self.assertEqual(output.shape, query.shape) + self.assertTrue(jnp.array_equal(output, query)) + + def test_ulysses_attention_raises_when_heads_are_not_divisible_by_context_shards(self): + """Ulysses attention should fail fast when heads cannot be evenly sharded.""" + batch = 2 + length = 5 + heads = 3 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_mesh() + + with mesh, nn_partitioning.axis_rules(self._ulysses_axis_rules()): + with self.assertRaisesRegex( + ValueError, + r"heads=3 and context_shards=2", + ): + attention_flax._ulysses_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ) + + def test_ulysses_attention_matches_flash_attention_with_same_local_kernel(self): + """Flash and Ulysses should agree when the local splash kernel is shared.""" + batch = 2 + length = 6 + heads = 4 + head_depth = 3 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 100.0 + value = query + 200.0 + mesh = self._ulysses_mesh() + + def fake_make_splash_mha(**unused_kwargs): + def fake_kernel(q, k, v, segment_ids): + del k, segment_ids + return q + jnp.mean(v, axis=1, keepdims=True) + + return fake_kernel + + with mock.patch.object( + attention_flax.splash_attention_kernel, + "make_splash_mha", + side_effect=fake_make_splash_mha, + ): + with mesh, nn_partitioning.axis_rules(self._flash_axis_rules()): + flash_output = attention_flax._tpu_flash_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + attention_kernel="flash", + ) + + with mesh, nn_partitioning.axis_rules(self._ulysses_axis_rules()): + ulysses_output = attention_flax._ulysses_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + 
attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ) + + self.assertEqual(flash_output.shape, ulysses_output.shape) + self.assertTrue(jnp.array_equal(flash_output, ulysses_output)) + + def test_ulysses_attention_uses_attention_mask_for_segment_ids(self): + """Ulysses attention should forward the attention mask into kv segment ids.""" + batch = 2 + length = 5 + heads = 4 + head_depth = 3 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 100.0 + value = query + 200.0 + attention_mask = jnp.array([[1, 0, 1, 0, 1]], dtype=jnp.int32) + mesh = self._ulysses_mesh() + + def fake_make_splash_mha(**unused_kwargs): + def fake_kernel(q, k, v, segment_ids): + del k, v + kv_mask = segment_ids.kv.astype(q.dtype)[None, :, None] + return q + kv_mask + + return fake_kernel + + with ( + mesh, + nn_partitioning.axis_rules(self._ulysses_axis_rules()), + mock.patch.object( + attention_flax.splash_attention_kernel, + "make_splash_mha", + side_effect=fake_make_splash_mha, + ), + ): + output = attention_flax._ulysses_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + attention_mask=attention_mask, + ) + + expected = query + jnp.broadcast_to(attention_mask[:, :, None], query.shape) + self.assertEqual(output.shape, query.shape) + self.assertTrue(jnp.array_equal(output, expected)) + if __name__ == "__main__": absltest.main()
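The new tests above build a `(data=1, fsdp=1, context=2, tensor=1)` mesh from `jax.devices()[:2]`, so they need at least two JAX devices. One possible way to exercise them on a CPU-only machine is to force extra XLA host devices before JAX initializes; the flag and pytest invocation below are assumptions about local tooling, not part of this change.

```python
# Hypothetical helper for running the new Ulysses tests on a single-host CPU
# setup: force two XLA host devices so the (data=1, fsdp=1, context=2, tensor=1)
# test mesh can be constructed. Flag value and pytest invocation are assumptions.
import os

# Must be set before JAX creates its backends (i.e. before importing jax).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=2"
)

import pytest  # noqa: E402

# Select only the Ulysses-related tests from the file touched by this change.
raise SystemExit(pytest.main(["-k", "ulysses", "src/maxdiffusion/tests/attention_test.py"]))
```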