
Commit 4e694ce

Add 2D context parallelism (Ulysses + Ring) for WAN attention

- Implement `_2d_context_attention()` combining Ulysses all-to-all with Ring ppermute
- Add `context_ulysses_parallelism` and `context_ring_parallelism` config params
- Add `ulysses_ring` attention type with automatic cross-attention fallback to flash
- Resolve merge with main (VAE decode fix + base2_exp/experimental_scheduler)

1 parent: c98002f

13 files changed: 1186 additions & 13 deletions

.gitignore

Lines changed: 1 addition & 0 deletions

```diff
@@ -184,3 +184,4 @@ gha-creds-*.json

 # JAX cache
 .jax_cache/
+sdxl-model-finetuned/
```

BENCHMARK_RESULTS.md

Lines changed: 74 additions & 0 deletions

# 2D Context Parallelism Benchmark Results

## Setup

- **Model**: WAN 2.2 T2V (Wan-AI/Wan2.2-T2V-A14B-Diffusers)
- **Hardware**: TPU v7-8 (4 chips x 2 cores = 8 devices)
- **Resolution**: 720x1280, 81 frames, 40 inference steps
- **Parallelism**: DP=2, CP=4 (per_device_batch_size=0.125, 1 video)
- **Flash block sizes**: block_q=2048, block_kv_compute=1024, block_kv=2048
- **VAE spatial sharding**: 8
- **Attention kernel**: tokamax splash (fused bwd kernel)

## Results

### Without XLA optimization flags

| Config | Compile (s) | Generation (s) | vs 2D Context |
|--------|-------------|----------------|---------------|
| **2D Context (U=2, R=2)** | 228.8 | **208.0** | baseline |
| Flash (CP=4) | 250.9 | **231.3** | +11.2% slower |
| Ring (CP=4) | 258.1 | **237.1** | +14.0% slower |

### With XLA optimization flags (LIBTPU_INIT_ARGS)

| Config | Compile (s) | Generation (s) | vs 2D Context |
|--------|-------------|----------------|---------------|
| **2D Context (U=2, R=2)** | 268.2 | **182.4** | baseline |
| Flash (CP=4) | 231.8 | **205.3** | +12.6% slower |
| Ring (CP=4) | -- | -- | not run |

### Key XLA flags used

```
--xla_tpu_enable_async_all_to_all=true
--xla_enable_async_collective_permute=true
--xla_tpu_enable_async_collective_fusion=true
--xla_tpu_overlap_compute_collective_tc=true
--xla_tpu_dvfs_p_state=7
--xla_tpu_scoped_vmem_limit_kib=65536
--xla_latency_hiding_scheduler_rerun=2
```

(plus additional scheduler/pipelining flags -- see benchmark script)
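These flags reach libtpu through the `LIBTPU_INIT_ARGS` environment variable, which must be set before JAX initializes the TPU backend. A minimal sketch of that wiring (illustrative only; the variable name `XLA_TPU_FLAGS` is ours, not from the benchmark script):

```python
import os

# Flag list copied from the table above. libtpu reads them as one
# space-separated string from LIBTPU_INIT_ARGS at backend init, so this
# must run before `import jax` (or before the first device is touched).
XLA_TPU_FLAGS = [
    "--xla_tpu_enable_async_all_to_all=true",
    "--xla_enable_async_collective_permute=true",
    "--xla_tpu_enable_async_collective_fusion=true",
    "--xla_tpu_overlap_compute_collective_tc=true",
    "--xla_tpu_dvfs_p_state=7",
    "--xla_tpu_scoped_vmem_limit_kib=65536",
    "--xla_latency_hiding_scheduler_rerun=2",
]
os.environ["LIBTPU_INIT_ARGS"] = " ".join(XLA_TPU_FLAGS)
print(os.environ["LIBTPU_INIT_ARGS"].count("--"))  # 7
```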
## Why 2D Context Parallelism Wins

TPU v7-8 topology: 4 chips with 2 cores each. Intra-chip bandwidth between
cores is **6x faster** than inter-chip ICI.

With U=2, R=2:

- **Ulysses all-to-all (U=2)**: Runs between the 2 cores on the same chip,
  using the 6x faster intra-chip link.
- **Ring ppermute (R=2)**: Only 1 rotation step across chips (R-1=1), cutting
  ICI communication relative to pure ring (CP=4, 3 rotation steps).

The async collective flags further help by overlapping the all-to-all and
ppermute communication with compute.
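The rotation-step arithmetic can be made concrete. A small illustrative model (the helper name `ring_rotations` is ours, not from the commit): a ring pass over R ranks needs R-1 KV rotations, so only the ring dimension pays inter-chip cost:

```python
def ring_rotations(ring_size: int) -> int:
    """A ring pass over `ring_size` ranks needs ring_size - 1 KV rotations."""
    return ring_size - 1

# Pure ring attention over all CP=4 devices: every rotation crosses chips.
pure_ring_steps = ring_rotations(4)

# 2D decomposition (U=2, R=2): the Ulysses all-to-all stays on the fast
# intra-chip link, so only the R=2 ring dimension rotates over ICI.
ring_2d_steps = ring_rotations(2)

print(pure_ring_steps, ring_2d_steps)  # 3 1
```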
## Implementation

The 2D context parallelism is activated with:

```yaml
attention: 'ulysses_ring'
ici_data_parallelism: 2
ici_context_parallelism: 4
context_ulysses_parallelism: 2
context_ring_parallelism: 2
```
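Such a config only makes sense when the Ulysses and Ring factors tile the context axis exactly. A sketch of the kind of check the `pyconfig.py` validation would perform (hypothetical helper, not the actual code from the diff):

```python
def validate_context_axes(ici_context: int, ulysses: int, ring: int) -> None:
    """Hypothetical check: the 2D factors must tile the context axis, U * R == CP."""
    if ulysses * ring != ici_context:
        raise ValueError(
            f"context_ulysses_parallelism ({ulysses}) * "
            f"context_ring_parallelism ({ring}) != "
            f"ici_context_parallelism ({ici_context})"
        )

validate_context_axes(4, 2, 2)  # the config above: valid
validate_context_axes(8, 2, 4)  # the benchmark's u2_r4 case: valid
```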
Files changed (git diff from main):

- `src/maxdiffusion/models/attention_flax.py` -- core `_2d_context_attention()` using tokamax splash
- `src/maxdiffusion/models/wan/transformers/transformer_wan.py` -- plumbing U/R params
- `src/maxdiffusion/pipelines/wan/wan_pipeline.py` -- config loading
- `src/maxdiffusion/pyconfig.py` -- axis rules and validation
- `src/maxdiffusion/common_types.py` -- ULYSSES_RING_ATTENTION_AXIS_RULES
- `src/maxdiffusion/configs/base_wan_14b.yml` -- config params
- `src/maxdiffusion/configs/base_wan_27b.yml` -- config params + block sizes
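To make the communication pattern concrete, here is a self-contained sketch of Ulysses + Ring attention in plain JAX `shard_map`, using 4 host-CPU devices to stand in for the U=2 x R=2 mesh. This illustrates the pattern only; the actual `_2d_context_attention()` in this commit uses the tokamax splash kernel and different internals:

```python
# Sketch only: Ulysses all-to-all + single Ring ppermute (R=2), non-causal.
import os
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map                          # newer JAX
except ImportError:
    from jax.experimental.shard_map import shard_map   # older JAX

SEQ, HEADS, DIM = 8, 4, 2
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("ulysses", "ring"))

def attn_shard(q, k, v):
    # Ulysses all-to-all: trade the sequence shard for a head shard, so each
    # device now holds SEQ/R tokens of HEADS/U heads.
    a2a = lambda x: jax.lax.all_to_all(x, "ulysses", split_axis=1, concat_axis=0, tiled=True)
    q, k, v = a2a(q), a2a(k), a2a(v)

    # Ring step: with R=2 one ppermute fetches the other ring rank's K/V.
    # Key order differs per rank, but softmax attention is permutation-
    # invariant over keys when K and V are permuted together.
    perm = [(0, 1), (1, 0)]
    k_all = jnp.concatenate([k, jax.lax.ppermute(k, "ring", perm)], axis=0)
    v_all = jnp.concatenate([v, jax.lax.ppermute(v, "ring", perm)], axis=0)

    scores = jnp.einsum("qhd,khd->hqk", q, k_all) / jnp.sqrt(float(DIM))
    out = jnp.einsum("hqk,khd->qhd", jax.nn.softmax(scores, axis=-1), v_all)

    # Inverse all-to-all restores the original sequence sharding.
    return jax.lax.all_to_all(out, "ulysses", split_axis=0, concat_axis=1, tiled=True)

spec = P(("ulysses", "ring"), None, None)  # sequence sharded over both axes
attn_2d = shard_map(attn_shard, mesh=mesh, in_specs=(spec, spec, spec), out_specs=spec)

key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(k_, (SEQ, HEADS, DIM)) for k_ in jax.random.split(key, 3))
out = attn_2d(q, k, v)

# Unsharded reference attention for comparison.
ref_scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(float(DIM))
ref = jnp.einsum("hqk,khd->qhd", jax.nn.softmax(ref_scores, axis=-1), v)
print(bool(jnp.allclose(out, ref, atol=1e-5)))  # True
```

Note the ring step here materializes the full K/V per device for brevity; a real ring kernel streams chunks with an online softmax instead.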

benchmark_2d_context.sh

Lines changed: 184 additions & 0 deletions

```bash
#!/bin/bash
# Benchmark script for comparing attention parallelism strategies on v7-8 TPU
# Runs training with synthetic data for a few steps and records step times.

set -e

REPO_DIR="/mnt/data/sagarchapara/workspace/maxdiffusion"
VENV="/mnt/data/sagarchapara/workspace/venv"
BASE_CONFIG="src/maxdiffusion/configs/base_wan_14b.yml"
RESULTS_DIR="/mnt/data/sagarchapara/workspace/benchmark_results"
METRICS_DIR="${RESULTS_DIR}/metrics"

export HF_HOME="/mnt/data/sagarchapara/cache/huggingface"
export JAX_COMPILATION_CACHE_DIR="/mnt/data/sagarchapara/cache/jax"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.95

source "${VENV}/bin/activate"
cd "${REPO_DIR}"

mkdir -p "${RESULTS_DIR}" "${METRICS_DIR}"

NUM_STEPS=10  # steps to run (first step includes compilation, skip it)

run_benchmark() {
  local name="$1"
  local attention="$2"
  local context_parallelism="$3"
  local ulysses_par="$4"
  local ring_par="$5"
  local metrics_file="${METRICS_DIR}/${name}.txt"
  local log_file="${RESULTS_DIR}/${name}.log"

  echo "=========================================="
  echo "Running benchmark: ${name}"
  echo "  attention=${attention}, ici_context_parallelism=${context_parallelism}"
  echo "  context_ulysses_parallelism=${ulysses_par}, context_ring_parallelism=${ring_par}"
  echo "=========================================="

  python src/maxdiffusion/train_wan.py \
    src/maxdiffusion/configs/base_wan_14b.yml \
    run_name="${name}" \
    attention="${attention}" \
    dataset_type="synthetic" \
    ici_data_parallelism=1 \
    ici_fsdp_parallelism=1 \
    ici_context_parallelism="${context_parallelism}" \
    ici_tensor_parallelism=1 \
    dcn_data_parallelism=1 \
    dcn_fsdp_parallelism=1 \
    dcn_context_parallelism=1 \
    dcn_tensor_parallelism=1 \
    context_ulysses_parallelism="${ulysses_par}" \
    context_ring_parallelism="${ring_par}" \
    max_train_steps="${NUM_STEPS}" \
    per_device_batch_size=1 \
    metrics_file="${metrics_file}" \
    write_metrics=True \
    enable_profiler=False \
    scan_layers=True \
    remat_policy="NONE" \
    checkpoint_every=-1 \
    save_final_checkpoint=False \
    skip_jax_distributed_system=False \
    base_output_directory="" \
    height=480 \
    width=832 \
    num_frames=81 \
    enable_ssim=False \
    2>&1 | tee "${log_file}"

  echo ""
  echo "Benchmark ${name} complete. Log: ${log_file}"
  echo ""
}

echo "============================================================"
echo "  2D Context Parallelism Benchmark Suite"
echo "  TPU v7-8 (4 chips x 2 cores = 8 devices)"
echo "  ${NUM_STEPS} training steps per config (step 0 = compilation)"
echo "============================================================"
echo ""

# 1. Pure Ring attention (context_parallelism=8)
run_benchmark "ring_cp8" "ring" 8 1 1

# 2. Pure Ulysses attention (context_parallelism=8)
run_benchmark "ulysses_cp8" "ulysses" 8 1 1

# 3. 2D context: Ulysses=2, Ring=4
run_benchmark "2d_u2_r4" "ulysses_ring" 8 2 4

# 4. 2D context: Ulysses=4, Ring=2
run_benchmark "2d_u4_r2" "ulysses_ring" 8 4 2

echo ""
echo "============================================================"
echo "  All benchmarks complete. Extracting results..."
echo "============================================================"
echo ""

# Extract step times from logs
python3 - <<'PYEOF'
import re
import os
import json

results_dir = "/mnt/data/sagarchapara/workspace/benchmark_results"
metrics_dir = os.path.join(results_dir, "metrics")
configs = ["ring_cp8", "ulysses_cp8", "2d_u2_r4", "2d_u4_r2"]

print("\n" + "=" * 70)
print("BENCHMARK RESULTS SUMMARY")
print("=" * 70)
print(f"{'Config':<20} {'Avg Step (s)':<15} {'Min Step (s)':<15} {'TFLOPS/dev':<15}")
print("-" * 70)

summary = {}
for config_name in configs:
    metrics_file = os.path.join(metrics_dir, f"{config_name}.txt")
    if not os.path.exists(metrics_file):
        print(f"{config_name:<20} {'NO DATA':<15}")
        continue

    step_times = []
    tflops_vals = []
    with open(metrics_file, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
                if "perf/step_time_seconds" in data.get("scalar", {}):
                    step_times.append(data["scalar"]["perf/step_time_seconds"])
                if "perf/per_device_tflops_per_sec" in data.get("scalar", {}):
                    tflops_vals.append(data["scalar"]["perf/per_device_tflops_per_sec"])
            except json.JSONDecodeError:
                # Not JSON; fall through to the log-file regex parsing below.
                pass

    if not step_times:
        # Fall back to parsing the captured stdout log
        log_file = os.path.join(results_dir, f"{config_name}.log")
        if os.path.exists(log_file):
            with open(log_file, "r") as f:
                for line in f:
                    m = re.search(r"step_time_seconds['\"]?\s*[:=]\s*([0-9.]+)", line)
                    if m:
                        step_times.append(float(m.group(1)))
                    m = re.search(r"per_device_tflops_per_sec['\"]?\s*[:=]\s*([0-9.]+)", line)
                    if m:
                        tflops_vals.append(float(m.group(1)))

    if step_times:
        # Skip first step (compilation)
        warmup = step_times[:1]
        steady = step_times[1:] if len(step_times) > 1 else step_times
        avg_time = sum(steady) / len(steady)
        min_time = min(steady)
        avg_tflops = sum(tflops_vals[1:]) / len(tflops_vals[1:]) if len(tflops_vals) > 1 else (tflops_vals[0] if tflops_vals else 0)
        print(f"{config_name:<20} {avg_time:<15.4f} {min_time:<15.4f} {avg_tflops:<15.2f}")
        summary[config_name] = {"avg_step_time": avg_time, "min_step_time": min_time, "avg_tflops": avg_tflops, "warmup_time": warmup[0] if warmup else 0}
    else:
        print(f"{config_name:<20} {'PARSE ERROR':<15}")

print("-" * 70)
if summary:
    best = min(summary.items(), key=lambda x: x[1]["avg_step_time"])
    print(f"\nBest config: {best[0]} with avg step time {best[1]['avg_step_time']:.4f}s")
    if "ring_cp8" in summary and "ulysses_cp8" in summary:
        ring_time = summary["ring_cp8"]["avg_step_time"]
        ulysses_time = summary["ulysses_cp8"]["avg_step_time"]
        for name, data in summary.items():
            if name.startswith("2d_"):
                speedup_vs_ring = (ring_time - data["avg_step_time"]) / ring_time * 100
                speedup_vs_ulysses = (ulysses_time - data["avg_step_time"]) / ulysses_time * 100
                print(f"{name}: {speedup_vs_ring:+.1f}% vs ring, {speedup_vs_ulysses:+.1f}% vs ulysses")

# Save summary
with open(os.path.join(results_dir, "summary.json"), "w") as f:
    json.dump(summary, f, indent=2)

print(f"\nDetailed results saved to {results_dir}/summary.json")
PYEOF
```
