Skip to content

Commit efbbdc8

Browse files
committed
LTX2 Performance enhancements
1 parent 46cae70 commit efbbdc8

51 files changed

Lines changed: 1344 additions & 189 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,6 @@ wandb
181181
# Gemini CLI
182182
.gemini/
183183
gha-creds-*.json
184+
185+
# JAX cache
186+
.jax_cache/

README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,26 @@ To generate images, run the following command:
572572
* For Wan2.2 T2V, use `base_wan_27b.yml`.
573573
* For Wan2.2 I2V, use `base_wan_i2v_27b.yml`.
574574

575+
### Ulysses Attention
576+
577+
MaxDiffusion supports Ulysses attention for WAN TPU inference. Enable it by setting `attention="ulysses"`.
578+
579+
Internally, this follows the Ulysses sequence-parallel attention pattern and trades sequence shards for head shards around the local TPU splash kernel. For background, see [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://arxiv.org/abs/2309.14509).
580+
581+
To enable Ulysses attention, set the corresponding override in your config YAML or pass it as a command-line override:
582+
583+
```bash
584+
python src/maxdiffusion/generate_wan.py \
585+
src/maxdiffusion/configs/base_wan_i2v_27b.yml \
586+
attention="ulysses" \
587+
ici_context_parallelism=4 \
588+
...
589+
```
590+
591+
Ulysses requires `ici_context_parallelism` greater than 1, and the number of attention heads must be divisible by the context shard count. `flash_block_sizes` tuning is optional and can still be used for hardware-specific tuning.
592+
593+
In our Wan2.2 I2V benchmarks at 40 inference steps, 81 frames, and `720x1280` resolution, Ulysses improved inference time by roughly `~10%` compared with flash attention, with about `~20s` lower latency on the v6e-8 and v7x-8 TPU setup.
594+
575595
### Caching Mechanisms
576596

577597
Wan 2.x pipelines support several caching strategies to accelerate inference by skipping redundant transformer forward passes. These are **mutually exclusive** — enable only one at a time.
@@ -773,3 +793,5 @@ This script will automatically format your code with `pyink` and help you identi
773793
774794
The full suite of -end-to end tests is in `tests` and `src/maxdiffusion/tests`. We run them with a nightly cadance.
775795
796+
## Profiling
797+
To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md).

docs/profiling.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# ML Diagnostics and Profiling
2+
3+
MaxDiffusion supports automated profiling and performance tracking via [Google Cloud ML Diagnostics](https://docs.cloud.google.com/tpu/docs/ml-diagnostics/sdk).
4+
5+
## 1. Manual Installation
6+
To keep the core MaxDiffusion repository lightweight and ensure it runs without dependencies for users who don't need profiling, the ML Diagnostics packages are **not** installed by default.
7+
8+
To use this feature, you must manually install the required package in your environment:
9+
```bash
10+
pip install google-cloud-mldiagnostics
11+
```
12+
13+
## 2. Configuration Settings
14+
To enable ML Diagnostics for your training or generation jobs, you need to update your configuration. You can either add these directly to your .yml config file or pass them as command-line arguments:
15+
16+
```yaml
17+
# ML Diagnostics settings
18+
enable_ml_diagnostics: True
19+
profiler_gcs_path: "gs://<your-bucket-name>/profiler/ml_diagnostics"
20+
enable_ondemand_xprof: True
21+
```
22+
23+
## 3. GCS Bucket Permissions (Troubleshooting)
24+
The GCS bucket you provide in `profiler_gcs_path` **must** have the correct IAM permissions to allow the Hypercompute Cluster service account to write data.
25+
26+
If permissions are not configured correctly, your job will fail with an error similar to this:
27+
> `message: 'service-32478767326@gcp-sa-hypercomputecluster.iam.gserviceaccount.com does not have storage.buckets.get access to the GCS bucket <your-bucket>: permission denied'`
28+
29+
**Fix:** Ensure you grant the required Storage roles (e.g., `Storage Object Admin`) to the service account mentioned in your error message for your specific GCS bucket.
30+
31+
## 4. Viewing Your Runs
32+
Once your job is running with diagnostics enabled, you can monitor the profiles, execution times, and metrics in the Cluster Director console here:
33+
34+
🔗 **https://pantheon.corp.google.com/cluster-director/diagnostics**

src/maxdiffusion/common_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,13 @@
8484
[CROSS_ATTN_Q_LENGTH, CONTEXT],
8585
[CROSS_ATTN_KV_LENGTH, None],
8686
]
87+
88+
### Common axis rules for ulysses attention ###
89+
ULYSSES_ATTENTION_AXIS_RULES = [
90+
[SELF_ATTN_HEAD, None],
91+
[SELF_ATTN_Q_LENGTH, CONTEXT],
92+
[SELF_ATTN_KV_LENGTH, CONTEXT],
93+
[CROSS_ATTN_HEAD, None],
94+
[CROSS_ATTN_Q_LENGTH, CONTEXT],
95+
[CROSS_ATTN_KV_LENGTH, CONTEXT],
96+
]

src/maxdiffusion/configs/base14.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
206206
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
207207
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
208208
adam_weight_decay: 1.e-2 # AdamW Weight decay
209+
opt_enable_grad_clipping: False
210+
max_grad_value: 1.0
211+
opt_enable_grad_global_norm_clipping: False
209212
max_grad_norm: 1.0
210213

211214
enable_profiler: False
@@ -244,3 +247,8 @@ quantization: ''
244247
quantization_local_shard_count: -1
245248
use_qwix_quantization: False
246249
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
250+
251+
# ML Diagnostics settings
252+
enable_ml_diagnostics: False
253+
profiler_gcs_path: ""
254+
enable_ondemand_xprof: False

src/maxdiffusion/configs/base21.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
211211
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
212212
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
213213
adam_weight_decay: 1.e-2 # AdamW Weight decay
214+
opt_enable_grad_clipping: False
215+
max_grad_value: 1.0
216+
opt_enable_grad_global_norm_clipping: False
214217
max_grad_norm: 1.0
215218

216219
enable_profiler: False
@@ -244,4 +247,9 @@ quantization: ''
244247
# Shard the range finding operation for quantization. By default this is set to number of slices.
245248
quantization_local_shard_count: -1
246249
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
247-
use_qwix_quantization: False
250+
use_qwix_quantization: False
251+
252+
# ML Diagnostics settings
253+
enable_ml_diagnostics: False
254+
profiler_gcs_path: ""
255+
enable_ondemand_xprof: False

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
221221
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
222222
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
223223
adam_weight_decay: 1.e-2 # AdamW Weight decay
224+
opt_enable_grad_clipping: False
225+
max_grad_value: 1.0
226+
opt_enable_grad_global_norm_clipping: False
224227
max_grad_norm: 1.0
225228

226229
enable_profiler: False
@@ -260,3 +263,8 @@ quantization: ''
260263
quantization_local_shard_count: -1
261264
use_qwix_quantization: False
262265
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
266+
267+
# ML Diagnostics settings
268+
enable_ml_diagnostics: False
269+
profiler_gcs_path: ""
270+
enable_ondemand_xprof: False

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
245245
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
246246
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
247247
adam_weight_decay: 0 # AdamW Weight decay
248+
opt_enable_grad_clipping: False
249+
max_grad_value: 1.0
250+
opt_enable_grad_global_norm_clipping: False
248251
max_grad_norm: 1.0
249252

250253
enable_profiler: False
@@ -303,3 +306,7 @@ quantization_local_shard_count: -1
303306
use_qwix_quantization: False
304307
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
305308

309+
# ML Diagnostics settings
310+
enable_ml_diagnostics: False
311+
profiler_gcs_path: ""
312+
enable_ondemand_xprof: False

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
232232
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
233233
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
234234
adam_weight_decay: 1.e-2 # AdamW Weight decay
235+
opt_enable_grad_clipping: False
236+
max_grad_value: 1.0
237+
opt_enable_grad_global_norm_clipping: False
235238
max_grad_norm: 1.0
236239

237240
enable_profiler: False
@@ -288,3 +291,7 @@ quantization_local_shard_count: -1
288291
use_qwix_quantization: False
289292
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
290293

294+
# ML Diagnostics settings
295+
enable_ml_diagnostics: False
296+
profiler_gcs_path: ""
297+
enable_ondemand_xprof: False

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
240240
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
241241
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
242242
adam_weight_decay: 1.e-2 # AdamW Weight decay
243+
opt_enable_grad_clipping: False
244+
max_grad_value: 1.0
245+
opt_enable_grad_global_norm_clipping: False
243246
max_grad_norm: 1.0
244247

245248
enable_profiler: False
@@ -297,4 +300,9 @@ quantization_local_shard_count: -1
297300
use_qwix_quantization: False
298301
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
299302

300-
save_final_checkpoint: False
303+
save_final_checkpoint: False
304+
305+
# ML Diagnostics settings
306+
enable_ml_diagnostics: False
307+
profiler_gcs_path: ""
308+
enable_ondemand_xprof: False

0 commit comments

Comments
 (0)