
Commit adde49d

working attention code
1 parent ceca471 commit adde49d

5 files changed

Lines changed: 427 additions & 10 deletions


benchmark_attention.sh

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
#!/bin/bash
# Benchmark: flash vs ring vs ulysses attention
# Usage: bash benchmark_attention.sh [flash|ring|ulysses|all]
#
# All runs use the same mesh layout: data=2, context=4.
# Only the attention kernel changes.

set -euo pipefail

# Activate virtualenv and set paths
source /data/maxdiffusion-work/maxdiffusion-venv/bin/activate
export PYTHONPATH=/data/maxdiffusion-work/maxdiffusion/src
export HF_HOME=/data/maxdiffusion-work/hf-home

CONFIG="src/maxdiffusion/configs/base_wan_i2v_14b.yml"
JAX_CACHE_DIR="${JAX_CACHE_DIR:-/tmp/jax_cache}"
LOG_DIR="/data/maxdiffusion-work/bench_logs"
mkdir -p "$LOG_DIR"

COMMON_ARGS=(
  "$CONFIG"
  num_inference_steps=50
  num_frames=81
  width=1280
  height=720
  ici_data_parallelism=2
  ici_context_parallelism=4
  ici_tensor_parallelism=1
  jax_cache_dir="$JAX_CACHE_DIR"
  per_device_batch_size=0.125
  allow_split_physical_axes=True
  "flash_block_sizes={\"block_q\":3024,\"block_kv_compute\":1024,\"block_kv\":2048,\"block_q_dkv\":3024,\"block_kv_dkv\":2048,\"block_kv_dkv_compute\":1024,\"use_fused_bwd_kernel\":true}"
)

run_variant() {
  local name="$1"
  local attention="$2"
  echo "========================================="
  echo "  Running: ${name} attention"
  echo "========================================="
  python src/maxdiffusion/generate_wan.py \
    "${COMMON_ARGS[@]}" \
    attention="$attention" \
    run_name="bench-${name}" \
    2>&1 | tee "${LOG_DIR}/bench_${name}.log"
  echo ""
  echo "${name} attention done. Log: ${LOG_DIR}/bench_${name}.log"
}

MODE="${1:-all}"

case "$MODE" in
  flash)   run_variant flash flash ;;
  ring)    run_variant ring ring ;;
  ulysses) run_variant ulysses ulysses ;;
  all)
    run_variant flash flash
    echo ""
    run_variant ring ring
    echo ""
    run_variant ulysses ulysses
    echo ""
    echo "========================================="
    echo "  All runs complete. Logs:"
    echo "    ${LOG_DIR}/bench_flash.log"
    echo "    ${LOG_DIR}/bench_ring.log"
    echo "    ${LOG_DIR}/bench_ulysses.log"
    echo "========================================="
    ;;
  *)
    echo "Usage: $0 [flash|ring|ulysses|all]"
    exit 1
    ;;
esac

docs/ulysses_attention_design.md

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
# Ulysses Sequence-Parallel Attention for MaxDiffusion

## Introduction

This doc proposes adding DeepSpeed Ulysses attention to the MaxDiffusion repo as a new sequence-parallel attention kernel. Ulysses uses all-to-all collectives to redistribute tensors across devices, enabling each device to compute full-sequence attention on a subset of heads. Benchmarks on Wan 2.1 I2V 14B (81 frames, 1280x720, 50 steps) show Ulysses completes inference in **282.8s** compared to flash at 307.1s and ring at 336.5s: a **24.3s (7.9%) wall-clock speedup** over flash and a **53.7s (16%) speedup** over ring on 8 TPU chips.

## Background / Motivation

### The sequence length problem in video diffusion

Video diffusion models produce long sequences. A single Wan 2.1 generation at 81 frames / 1280x720 yields sequence lengths in the tens of thousands of tokens. At these lengths, attention becomes the dominant compute and memory bottleneck.

MaxDiffusion currently supports two approaches to distributing attention across devices:

1. **Flash attention with sequence-parallel sharding** (`attention=flash`): When `ici_context_parallelism > 1`, Q is sharded across the `context` mesh axis while K/V are replicated across all devices. Each device computes attention on its Q shard against the full K/V. This reduces per-device memory for Q but still requires every device to hold the full K/V. Uses Pallas splash attention kernels on TPU.

2. **Ring attention** (`attention=ring`): Both Q and K/V are sharded across the `context` mesh axis. Each device holds a shard of Q and K/V, then iteratively passes K/V blocks to neighbors via `ppermute`, accumulating partial softmax results with online log-sum-exp correction (see the sketch after this list). This enables arbitrarily long sequences since no device needs to hold the full K/V, but introduces O(num_shards) sequential communication rounds.
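For concreteness, here is a minimal sketch of the ring schedule described in item 2, written with plain JAX ops rather than the repo's Pallas splash kernel. The function name and the explicit `axis_size` argument are assumptions for illustration; it would run inside a `shard_map`/`pmap` over the `context` axis.

```python
import jax
import jax.numpy as jnp

def ring_attention_sketch(q, k, v, axis_name="context", axis_size=4):
  # Runs inside shard_map/pmap over `axis_name`; q, k, v are this device's
  # shards of shape (batch, heads, seq_shard, head_dim).
  perm = [(i, (i + 1) % axis_size) for i in range(axis_size)]
  scale = q.shape[-1] ** -0.5

  def step(carry, _):
    k_blk, v_blk, m, l, acc = carry
    # Partial attention of the local Q shard against the current K/V block.
    s = jnp.einsum("bhqd,bhkd->bhqk", q, k_blk) * scale
    # Online softmax: rescale the running stats to the new running max.
    m_new = jnp.maximum(m, s.max(axis=-1))
    p = jnp.exp(s - m_new[..., None])
    corr = jnp.exp(m - m_new)
    l = l * corr + p.sum(axis=-1)
    acc = acc * corr[..., None] + jnp.einsum("bhqk,bhkd->bhqd", p, v_blk)
    # Pass the K/V block to the next device in the ring for the next round.
    k_blk = jax.lax.ppermute(k_blk, axis_name, perm)
    v_blk = jax.lax.ppermute(v_blk, axis_name, perm)
    return (k_blk, v_blk, m_new, l, acc), None

  init = (k, v,
          jnp.full(q.shape[:-1], -jnp.inf, dtype=q.dtype),  # running max
          jnp.zeros(q.shape[:-1], dtype=q.dtype),           # running denominator
          jnp.zeros_like(q))                                 # running numerator
  (_, _, _, l, acc), _ = jax.lax.scan(step, init, None, length=axis_size)
  return acc / l[..., None]
```

The point to notice is the O(axis_size) sequential `ppermute` rounds and the per-round rescaling; these are exactly the costs Ulysses avoids.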
### Why Ulysses

[Ulysses](https://arxiv.org/abs/2309.14509) uses **two all-to-all collectives** to trade the sequence dimension for the head dimension: each device gets the full sequence but a subset of heads, runs local attention, then all-to-alls the result back. This replaces ring's O(N) sequential `ppermute` rounds with two bulk transfers.

Ring can overlap communication with compute, which makes it attractive at large scale. But at smaller chip counts (4-8 devices), the per-round compute is too small to hide the latency, making ring communication-bound. Ulysses converts this into a single compute-bound local attention, which TPUs are optimized for.

**This does not replace ring attention.** Ring remains better at large scale, where overlap fully hides communication. Ulysses is complementary: a better fit at smaller scale (4-8 devices), the common setup for single-node inference and fine-tuning. [HuggingFace Diffusers](https://huggingface.co/docs/diffusers/training/distributed_inference) found the same: Ulysses was ~2x faster than Ring on 4 H100 GPUs. They also support a **Unified Attention** hybrid (Ulysses + Ring in a 2D grid) for scaling beyond either alone, a potential future direction for MaxDiffusion.

| Property | Flash (seq-parallel) | Ring | Ulysses |
|---|---|---|---|
| Sequence distribution | Q sharded, K/V replicated | Q and K/V sharded | Q and K/V sharded |
| Communication pattern | None (replication cost) | O(N) sequential ppermute rounds (can overlap with compute) | 2 all-to-all collectives (not overlapped) |
| Local attention | Q shard vs full K/V, all heads | Q shard vs K/V shard, all heads | Full sequence, head subset |
| Bottleneck at small N | K/V memory per device | Communication-bound (overlap insufficient) | Compute-bound (favorable for TPU) |
| Bottleneck at large N | K/V memory per device | Compute-bound (overlap hides comms) | Head dimension too small per device |
| Constraint | K/V must fit per device | None | num_heads must be divisible by num_shards (padded if not) |
## Design / Proposal

### Ulysses Attention Formula

Given N devices and input tensors Q, K, V sharded along the sequence dimension across N devices:

```
Ulysses Attention (per device i)

  Q_i, K_i, V_i ∈ R^{b × h × (s/N) × d}

  1. Q', K', V' = AllToAll(Q_i, K_i, V_i)
     Q', K', V' ∈ R^{b × (h/N) × s × d}
     (trade: seq shards → head shards)

  2. O' = SplashAttention(Q', K', V')
     O' ∈ R^{b × (h/N) × s × d}
     (standard attention on full seq, head subset)

  3. O_i = AllToAll(O')
     O_i ∈ R^{b × h × (s/N) × d}
     (trade: head shards → seq shards)
```

The output is mathematically equivalent to standard multi-head attention: the all-to-all simply redistributes which device computes which heads. No numerical approximation or online softmax correction is needed.
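To make the three steps concrete, here is a minimal, hypothetical JAX sketch of the Ulysses pattern built from `jax.lax.all_to_all` under `shard_map`. It is not the MaxDiffusion implementation: it substitutes a plain softmax attention for the splash kernel, and the mesh and axis names are assumptions for the example.

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

def _local_attention(q, k, v):
  # Standard attention over the full sequence, on this device's head subset.
  s = jnp.einsum("bhqd,bhkd->bhqk", q, k) * (q.shape[-1] ** -0.5)
  return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(s, axis=-1), v)

def ulysses_attention(q, k, v, mesh, axis_name="context"):
  # q, k, v: (batch, heads, seq, head_dim) with the seq dimension sharded over
  # `axis_name`; heads must be divisible by the axis size.
  def per_shard(q, k, v):
    # 1. seq shards -> head shards: (b, h, s/N, d) -> (b, h/N, s, d)
    a2a = lambda x: jax.lax.all_to_all(x, axis_name, split_axis=1, concat_axis=2, tiled=True)
    # 2. standard attention on the full sequence, subset of heads
    o = _local_attention(a2a(q), a2a(k), a2a(v))
    # 3. head shards -> seq shards: (b, h/N, s, d) -> (b, h, s/N, d)
    return jax.lax.all_to_all(o, axis_name, split_axis=2, concat_axis=1, tiled=True)

  spec = P(None, None, axis_name, None)  # shard the sequence dimension
  return shard_map(per_shard, mesh=mesh, in_specs=(spec, spec, spec), out_specs=spec)(q, k, v)
```

With the benchmark layout, `mesh` would be a `Mesh` over 8 devices with axis names `("data", "context")` of sizes 2 and 4, and the head count must be divisible by the context size (4 here) or padded first.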
### High-level approach

Ulysses is added as another attention mechanism alongside flash, ring, and dot_product. Users select it at runtime via `attention: ulysses` in the YAML config; no other config changes needed.

**Key design decisions:**

- **Reuse existing flash attention kernel**: The local attention inside Ulysses is the same Pallas splash attention kernel (`make_splash_mha`) already used by flash and ring. We add two `jax.lax.all_to_all` collectives before and after it; that is the core of the change.
- **Standalone op, no changes to existing code**: Ulysses is a new function (`_ulysses_attention`) added next to the existing `_tpu_flash_attention`. Existing flash and ring paths are untouched. One `elif` branch is added to the router (`_apply_attention`) to dispatch to it.
- **Same mesh axis as ring**: Uses the `context` axis for sequence parallelism. Any mesh config that works with ring works with Ulysses; just swap the attention type.

### What changes

| File | Change | Lines |
|---|---|---|
| `common_types.py` | Add `ULYSSES_ATTENTION_AXIS_RULES` (same sharding pattern as ring) | +12 |
| `pyconfig.py` | Detect `attention == "ulysses"` and inject axis rules | +28 |
| `attention_flax.py` | Add `_ulysses_attention()` function + routing branch | +175 |
| **Total** | **3 files, 0 existing lines modified** | **+215** |
### How it reuses existing code

The implementation reuses existing utilities rather than reimplementing them:

- `_reshape_data_for_flash()`: reshape and pad for context sharding
- `_pad_data_for_flash()`: pad to flash block size
- `splash_attention_kernel.make_splash_mha()`: the splash attention kernel itself
- `_reshape_heads_to_head_dim()`: reshape output back

### Config

```yaml
attention: ulysses          # new option alongside flash, ring, dot_product
ici_data_parallelism: 2
ici_context_parallelism: 4  # same setting as ring
ici_tensor_parallelism: 1
```
## Benchmark Results

**Model**: Wan 2.1 I2V 14B
**Config**: 81 frames, 1280x720, 50 inference steps, per_device_batch_size=0.125
**Hardware**: 8 TPU chips, mesh = data:2 x context:4

| Kernel | Compile Time | Inference Time | vs Flash | vs Ring |
|---|---|---|---|---|
| Flash | 494.6s | 307.1s | baseline | n/a |
| Ring | 371.1s | 336.5s | +9.6% slower | baseline |
| **Ulysses** | **319.6s** | **282.8s** | **-7.9% faster** | **-16.0% faster** |
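The relative figures in the table follow directly from the wall-clock times; a quick check:

```python
# Quick arithmetic check of the percentages quoted above.
flash, ring, ulysses = 307.1, 336.5, 282.8
print(f"ring vs flash:    +{(ring - flash) / flash:.1%} slower")     # +9.6%
print(f"ulysses vs flash: -{(flash - ulysses) / flash:.1%} faster")  # -7.9%
print(f"ulysses vs ring:  -{(ring - ulysses) / ring:.1%} faster")    # -16.0%
```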
### Why Ulysses is faster

1. **Communication efficiency**: Two all-to-all collectives vs. O(N-1) sequential ppermute rounds in ring. On TPU interconnect, all-to-all is a single bulk transfer that saturates bandwidth, while ppermute serializes transfers.

2. **No online softmax correction**: Ring attention must compute partial attention with log-sum-exp residuals (`save_residuals=True`) and accumulate across rounds with numerical correction (exp, max tracking). Ulysses sees the full sequence locally and runs standard attention (`save_residuals=False`), with no correction overhead.

3. **Simpler compute graph**: Ulysses compiles to a simpler XLA graph (all-to-all → attention → all-to-all) vs. ring's `jax.lax.scan` loop with ppermute and online softmax accumulation. This also explains the faster compile time.

4. **Compared to flash (seq-parallel)**: Flash shards Q across the `context` axis but replicates K/V, so each device still materializes the full-sequence K/V. With context=4, Ulysses shards both Q and K/V and reassembles the sequence via all-to-all, which can be easier on the memory subsystem even though it adds communication.
### Requirements

- `ici_context_parallelism > 1` (otherwise Ulysses degenerates to standard flash with unnecessary overhead).
- The number of attention heads should ideally be divisible by `ici_context_parallelism` to avoid head-padding overhead. When it is not divisible, heads are padded before the all-to-all and stripped after (see the sketch below).
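A hypothetical sketch of that padding fallback (helper names and shapes are assumptions, not the repo's actual helpers):

```python
import jax.numpy as jnp

def pad_heads(x, num_shards):
  # x: (batch, heads, seq, head_dim). Pad the head axis up to a multiple of
  # num_shards so the all-to-all can split it evenly.
  pad = (-x.shape[1]) % num_shards
  if pad:
    x = jnp.pad(x, ((0, 0), (0, pad), (0, 0), (0, 0)))
  return x, pad

def strip_heads(x, pad):
  # Drop the padded heads from the output after the second all-to-all.
  return x[:, : x.shape[1] - pad] if pad else x
```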
## Open Questions

1. **Training support**: This implementation has been benchmarked for inference only. Training with Ulysses should work (all-to-all is differentiable, splash attention backward is supported with `save_residuals=False`) but has not been validated end-to-end. Does the backward pass through the two all-to-all collectives produce correct gradients?

2. **Cross-attention behavior**: In Wan 2.1, both self-attention and cross-attention route through the same `_apply_attention` dispatcher. Ulysses applies all-to-all to both. For cross-attention, where the KV sequence length is much shorter than Q, the all-to-all overhead may not be worth it. Should we fall back to flash for cross-attention?

3. **Optimal context parallelism degree**: With 4 context shards, Ulysses is 7.9% faster than flash. How does this scale at 2 or 8 shards? At what point does all-to-all overhead exceed the benefit?

4. **Interaction with tensor parallelism**: Currently tested with `ici_tensor_parallelism=1`. If tensor parallelism is also enabled (heads already split across the tensor axis), the head split in Ulysses would interact with it. This combination has not been tested.

5. **GQA / MQA support**: The current implementation pads and splits heads uniformly for Q, K, and V. For grouped-query attention, where K/V have fewer heads than Q, the all-to-all split would need different handling per tensor.

6. **Unified Attention (Ulysses + Ring hybrid)**: HuggingFace Diffusers supports a [Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) mode that combines Ulysses and Ring into a 2D grid: Ulysses parallelizes across heads while Ring parallelizes across sequence within each Ulysses group. This enables scaling beyond the head-count limit. Should MaxDiffusion support a `ulysses_ring` or similar combined mode?

src/maxdiffusion/common_types.py

Lines changed: 12 additions & 0 deletions
@@ -84,3 +84,15 @@
  [CROSS_ATTN_Q_LENGTH, CONTEXT],
  [CROSS_ATTN_KV_LENGTH, None],
]

### Common axis rules for ulysses attention ###
# Ulysses shards sequence on context axis (like ring), but uses all-to-all
# collectives to trade sequence shards for head shards before local attention.
ULYSSES_ATTENTION_AXIS_RULES = [
  [SELF_ATTN_HEAD, None],
  [SELF_ATTN_Q_LENGTH, CONTEXT],
  [SELF_ATTN_KV_LENGTH, CONTEXT],
  [CROSS_ATTN_HEAD, None],
  [CROSS_ATTN_Q_LENGTH, CONTEXT],
  [CROSS_ATTN_KV_LENGTH, CONTEXT],
]
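Illustratively (not repo code), these rules say a self-attention activation of shape (batch, heads, q_length, head_dim) keeps heads replicated and puts the sequence dimension on the `context` mesh axis. A minimal sketch, assuming the benchmark's 8-chip data:2 x context:4 mesh and placing batch on the `data` axis for the example:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 8 local devices arranged as data=2 x context=4, as in the benchmark.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("data", "context"))

# (batch, heads, q_length, head_dim): heads replicated, sequence on 'context';
# putting batch on 'data' is an assumption for the illustration.
self_attn_sharding = NamedSharding(mesh, P("data", None, "context", None))
```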
