# Ulysses Sequence-Parallel Attention for MaxDiffusion

## Introduction

This doc proposes adding DeepSpeed Ulysses attention to the MaxDiffusion repo as a new sequence-parallel attention kernel. Ulysses uses all-to-all collectives to redistribute tensors across devices so that each device computes full-sequence attention on a subset of heads. Benchmarks on Wan 2.1 I2V 14B (81 frames, 1280x720, 50 steps) show Ulysses completing inference in **282.8s**, versus 307.1s for flash and 336.5s for ring — a **24.3s (7.9%) wall-clock speedup** over flash and a **53.7s (16.0%) speedup** over ring on 8 TPU chips.

## Background / Motivation

### The sequence length problem in video diffusion

Video diffusion models produce long sequences. A single Wan 2.1 generation at 81 frames / 1280x720 yields sequence lengths in the tens of thousands of tokens. At these lengths, attention becomes the dominant compute and memory bottleneck.
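
To make "tens of thousands" concrete, a rough back-of-envelope — the compression factors below are assumptions for illustration, not values taken from this repo:

```python
# Assumed factors (illustration only): Wan 2.1's 3D VAE compresses 4x in time
# and 8x in space, and the DiT patchifies latents with a (1, 2, 2) patch size.
frames, height, width = 81, 720, 1280
latent_frames = (frames - 1) // 4 + 1                       # 21
latent_h, latent_w = height // 8, width // 8                # 90, 160
tokens = latent_frames * (latent_h // 2) * (latent_w // 2)  # 21 * 45 * 80
print(tokens)                                               # 75600
```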

MaxDiffusion currently supports two approaches to distributing attention across devices (per-tensor sharding sketched below):

1. **Flash attention with sequence-parallel sharding** (`attention=flash`): when `ici_context_parallelism > 1`, Q is sharded across the `context` mesh axis while K/V are replicated on every device. Each device computes attention for its Q shard against the full K/V. This reduces per-device memory for Q but still requires every device to hold the full K/V. Uses the Pallas splash attention kernels on TPU.

2. **Ring attention** (`attention=ring`): both Q and K/V are sharded across the `context` mesh axis. Each device holds a shard of Q and K/V, then iteratively passes K/V blocks to its neighbors via `ppermute`, accumulating partial softmax results with online log-sum-exp correction. This supports arbitrarily long sequences, since no device ever holds the full K/V, but it introduces O(num_shards) sequential communication rounds.
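
To make the difference concrete, here is a sketch of the per-tensor partition specs the two modes imply for activations laid out as (batch, heads, seq, head_dim). The axis names and specs are illustrative; MaxDiffusion's actual logical-axis rules live in `common_types.py` and may differ.

```python
from jax.sharding import PartitionSpec as P

# Illustrative only: mesh axes ("data", "context"), activations (batch, heads, seq, head_dim).
flash_q_spec  = P("data", None, "context", None)   # Q: sequence sharded
flash_kv_spec = P("data", None, None, None)        # K/V: replicated across "context"

ring_q_spec   = P("data", None, "context", None)   # Q: sequence sharded
ring_kv_spec  = P("data", None, "context", None)   # K/V: sequence sharded too
```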

### Why Ulysses

[Ulysses](https://arxiv.org/abs/2309.14509) uses **two all-to-all collectives** to trade the sequence dimension for the head dimension — each device receives the full sequence but only a subset of heads, runs local attention, then all-to-alls back. This replaces ring's O(N) sequential `ppermute` rounds with two bulk transfers.

Ring can overlap communication with compute, which makes it attractive at large scale. At smaller chip counts (4-8 devices), however, the per-round compute is too small to hide the transfer latency, so ring becomes communication-bound. Ulysses instead runs a single compute-bound local attention over the full sequence — exactly what TPUs are optimized for.

**This does not replace ring attention.** Ring remains better at large scale, where overlap fully hides communication. Ulysses is complementary — a better fit at smaller scale (4-8 devices), the common setup for single-node inference and fine-tuning. [HuggingFace Diffusers](https://huggingface.co/docs/diffusers/training/distributed_inference) found the same: Ulysses was ~2x faster than Ring on 4 H100 GPUs. They also support a **Unified Attention** hybrid (Ulysses + Ring in a 2D grid) for scaling beyond either alone — a potential future direction for MaxDiffusion.

| Property | Flash (seq-parallel) | Ring | Ulysses |
|---|---|---|---|
| Sequence distribution | Q sharded, K/V replicated | Q and K/V sharded | Q and K/V sharded |
| Communication pattern | None (replication cost) | O(N) sequential ppermute rounds (can overlap with compute) | 2 all-to-all collectives (not overlapped) |
| Local attention | Q shard vs full K/V, all heads | Q shard vs K/V shard, all heads | Full sequence, head subset |
| Bottleneck at small N | K/V memory per device | Communication-bound (overlap insufficient) | Compute-bound (favorable for TPU) |
| Bottleneck at large N | K/V memory per device | Compute-bound (overlap hides comms) | Head dimension too small per device |
| Constraint | K/V must fit per device | None | num_heads must be divisible by num_shards (padded if not) |
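
A rough per-device communication count (a standard back-of-envelope, not a measurement from this repo): with N context shards and b × h × s × d elements per tensor, ring sends each device's K/V shard to a neighbor on every one of its N−1 rounds, roughly 2·b·h·s·d·(N−1)/N elements per device, serialized across the rounds. Ulysses all-to-alls Q, K, V on the way in and O on the way out; each all-to-all moves only the (N−1)/N fraction of a local shard that changes owner, roughly 4·b·h·s·d·(N−1)/N² elements per device, in two bulk phases. The ratio is 2/N, so at N = 4 Ulysses moves about half the bytes of ring and without ring's round-by-round dependency.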

## Design / Proposal

### Ulysses Attention Formula

Given N devices and input tensors Q, K, V sharded along the sequence dimension:

```
┌──────────────────────────────────────────────────┐
│ Ulysses Attention (per device i)                 │
│                                                  │
│ Q_i, K_i, V_i ∈ R^{b × h × (s/N) × d}            │
│                                                  │
│ 1. Q', K', V' = AllToAll(Q_i, K_i, V_i)          │
│    Q', K', V' ∈ R^{b × (h/N) × s × d}            │
│    (trade: seq shards → head shards)             │
│                                                  │
│ 2. O' = SplashAttention(Q', K', V')              │
│    O' ∈ R^{b × (h/N) × s × d}                    │
│    (standard attention on full seq, head subset) │
│                                                  │
│ 3. O_i = AllToAll(O')                            │
│    O_i ∈ R^{b × h × (s/N) × d}                   │
│    (trade: head shards → seq shards)             │
└──────────────────────────────────────────────────┘
```

The output is mathematically equivalent to standard multi-head attention — the all-to-all simply redistributes which device computes which heads. No numerical approximation or online softmax correction is needed.
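
The core of the proposed `_ulysses_attention` is just this redistribution around a local attention call. Below is a minimal, self-contained sketch of the idea under `shard_map` — plain softmax attention stands in for the Pallas splash kernel, and the function and variable names are illustrative rather than the repo's actual implementation:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def local_attention(q, k, v):
  # q, k, v: (batch, heads_local, seq_full, head_dim). Plain softmax attention
  # stands in for the Pallas splash attention kernel used on the real TPU path.
  scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)

def ulysses_attention(q, k, v, axis_name="context"):
  # Inputs are sequence shards: (batch, heads, seq / N, head_dim).
  # 1. All-to-all: trade sequence shards for head shards.
  a2a = lambda x: jax.lax.all_to_all(x, axis_name, split_axis=1, concat_axis=2, tiled=True)
  q, k, v = a2a(q), a2a(k), a2a(v)        # now (batch, heads / N, seq, head_dim)
  # 2. Standard attention over the full sequence for the local head subset.
  out = local_attention(q, k, v)
  # 3. All-to-all back: trade head shards for sequence shards.
  return jax.lax.all_to_all(out, axis_name, split_axis=2, concat_axis=1, tiled=True)

# Usage on a data:2 x context:4 mesh (8 devices), matching the benchmark layout below.
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ("data", "context"))
spec = P("data", None, "context", None)   # shard batch over "data", sequence over "context"
sharded_attn = shard_map(ulysses_attention, mesh=mesh,
                         in_specs=(spec, spec, spec), out_specs=spec, check_rep=False)
```

The `tiled` all-to-all requires the head axis to be divisible by the context-parallel degree; the padding fallback is described under Requirements below.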

### High-level approach

Ulysses is added as another attention mechanism alongside flash, ring, and dot_product. Users select it at runtime via `attention: ulysses` in the YAML config — no other config changes needed.

**Key design decisions:**
- **Reuse existing flash attention kernel**: The local attention inside Ulysses is the same Pallas splash attention kernel (`make_splash_mha`) already used by flash and ring. We add two `jax.lax.all_to_all` collectives before and after it — that's the core of the change.
- **Standalone op, no changes to existing code**: Ulysses is a new function (`_ulysses_attention`) added next to the existing `_tpu_flash_attention`. Existing flash and ring paths are untouched. One `elif` branch is added to the router (`_apply_attention`) to dispatch to it.
- **Same mesh axis as ring**: Uses the `context` axis for sequence parallelism. Any mesh config that works with ring works with Ulysses — just swap the attention type.

### What changes

| File | Change | Lines |
|---|---|---|
| `common_types.py` | Add `ULYSSES_ATTENTION_AXIS_RULES` (same sharding pattern as ring) | +12 |
| `pyconfig.py` | Detect `attention == "ulysses"` and inject axis rules | +28 |
| `attention_flax.py` | Add `_ulysses_attention()` function + routing branch | +175 |
| **Total** | **3 files, 0 existing lines modified** | **+215** |

### How it reuses existing code

The implementation reuses existing utilities rather than reimplementing them:
- `_reshape_data_for_flash()` — reshape and pad for context sharding
- `_pad_data_for_flash()` — pad to flash block size
- `splash_attention_kernel.make_splash_mha()` — the splash attention kernel itself
- `_reshape_heads_to_head_dim()` — reshape output back

### Config

```yaml
attention: ulysses # new option alongside flash, ring, dot_product
ici_data_parallelism: 2
ici_context_parallelism: 4 # same setting as ring
ici_tensor_parallelism: 1
```
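
With this config the Ulysses all-to-alls run over the 4-way `context` axis group, and the product of the `ici_*_parallelism` values (2 × 4 × 1 = 8) matches the number of chips — the same layout as the benchmark below.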

## Benchmark Results

- **Model**: Wan 2.1 I2V 14B
- **Config**: 81 frames, 1280x720, 50 inference steps, per_device_batch_size=0.125
- **Hardware**: 8 TPU chips, mesh = data:2 x context:4
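
With per_device_batch_size=0.125 across 8 chips, this works out to a global batch of one video per step.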

| Kernel | Compile Time | Inference Time | vs Flash | vs Ring |
|---|---|---|---|---|
| Flash | 494.6s | 307.1s | baseline | — |
| Ring | 371.1s | 336.5s | 9.6% slower | baseline |
| **Ulysses** | **319.6s** | **282.8s** | **7.9% faster** | **16.0% faster** |

### Why Ulysses is faster

1. **Communication efficiency**: Two all-to-all collectives vs. O(N-1) sequential ppermute rounds in ring. On TPU interconnect, all-to-all is a single bulk transfer that saturates bandwidth, while ppermute serializes transfers.

2. **No online softmax correction**: Ring attention must compute partial attention with log-sum-exp residuals (`save_residuals=True`) and accumulate across rounds with numerical correction (exp, max tracking). Ulysses sees the full sequence locally and runs standard attention (`save_residuals=False`) — no correction overhead.

3. **Simpler compute graph**: Ulysses compiles to a simpler XLA graph (all-to-all → attention → all-to-all) vs. ring's `jax.lax.scan` loop with ppermute and online softmax accumulation. This also explains the faster compile time.

4. **Compared to flash (sequence-parallel)**: The flash path replicates K/V, so every device must hold and stream the full K/V sequence. With context=4, Ulysses shards K/V and reassembles the full sequence (for a head subset) via all-to-all, which can be more efficient for the memory subsystem even though it adds communication.

### Requirements

- `ici_context_parallelism > 1` (otherwise Ulysses degenerates to standard flash with unnecessary overhead).
- The number of attention heads should ideally be divisible by `ici_context_parallelism` to avoid head-padding overhead. When it is not, heads are padded before the all-to-all and stripped afterwards, as sketched below.
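
A minimal sketch of that padding rule (the helper names are hypothetical, not functions from the repo):

```python
import jax.numpy as jnp

def pad_heads(x, num_shards):
  # x: (batch, heads, seq, head_dim). Pad heads up to the next multiple of
  # num_shards so the tiled all-to-all can split them evenly.
  heads = x.shape[1]
  padded = -(-heads // num_shards) * num_shards   # ceil(heads / num_shards) * num_shards
  if padded == heads:
    return x, heads
  return jnp.pad(x, ((0, 0), (0, padded - heads), (0, 0), (0, 0))), heads

def strip_heads(x, original_heads):
  # Drop the zero-padded heads from the output after the second all-to-all.
  return x[:, :original_heads]
```

For example, 40 heads over 4 context shards split evenly (10 heads per device), while 40 heads over 3 shards would be padded to 42 and the two extra heads dropped from the output.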

## Open Questions

1. **Training support**: This implementation has been benchmarked for inference only. Training with Ulysses should work (all-to-all is differentiable, splash attention backward is supported with `save_residuals=False`) but has not been validated end-to-end. Does the backward pass through the two all-to-all collectives produce correct gradients?

2. **Cross-attention behavior**: In Wan 2.1, both self-attention and cross-attention route through the same `_apply_attention` dispatcher. Ulysses applies all-to-all to both. For cross-attention, where the KV sequence length is much shorter than Q's, the all-to-all overhead may not be worth it. Should we fall back to flash for cross-attention?

3. **Optimal context parallelism degree**: With 4 context shards, Ulysses is 7.9% faster than flash. How does this scale at 2 or 8 shards? At what point does all-to-all overhead exceed the benefit?

4. **Interaction with tensor parallelism**: Currently tested with `ici_tensor_parallelism=1`. If tensor parallelism is also enabled (heads already split across the tensor axis), the head split in Ulysses would interact with it. This combination has not been tested.

5. **GQA / MQA support**: The current implementation pads and splits heads uniformly for Q, K, and V. For grouped-query attention, where K/V have fewer heads than Q, the all-to-all split would need different handling per tensor.

6. **Unified Attention (Ulysses + Ring hybrid)**: HuggingFace Diffusers supports a [Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) mode that combines Ulysses and Ring into a 2D grid — Ulysses parallelizes across heads while Ring parallelizes across sequence within each Ulysses group. This enables scaling beyond the head-count limit. Should MaxDiffusion support a `ulysses_ring` or similar combined mode?