Skip to content

Commit 4ba616c

Browse files
committed
Merge origin/main and resolve conflicts. Added Tokamax Ring Attention support to README.
2 parents f2a0e3c + 18f6f0f commit 4ba616c

43 files changed

Lines changed: 489 additions & 109 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

README.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
[![Unit Tests](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)
1818

1919
# What's new?
20+
- **`2026/04/16`**: Support for Tokamax Ring Attention kernel is now added.
2021
- **`2026/03/31`**: Wan2.2 SenCache inference is now supported for T2V and I2V (up to 1.4x speedup)
2122
- **`2026/03/25`**: Wan2.1 and Wan2.2 Magcache inference is now supported
2223
- **`2026/03/25`**: LTX-2 Video Inference is now supported
@@ -535,6 +536,12 @@ To generate images, run the following command:
535536

536537
Supports both Text2Vid and Img2Vid pipelines.
537538

539+
**Note**: The product of per_device_batch_size and num_devices must be equal to a whole number.
540+
541+
The below command uses 4 devices and a per_device_batch_size=0.25. Thus, 4 * 0.25 = 1. This will generate a single video. Setting per_device_batch_size to 0.5, will generate 2 videos and so on.
542+
543+
If using 8 devices, then per_device_batch_size=0.125 will generate 1 video, per_device_batch_size=0.25 generates 2 videos.
544+
538545
The following command will run Wan2.1 T2V:
539546

540547
```bash
@@ -553,7 +560,7 @@ To generate images, run the following command:
553560
width=1280 \
554561
height=720 \
555562
jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ \
556-
per_device_batch_size=.125 \
563+
per_device_batch_size=.0.25 \
557564
ici_data_parallelism=2 \
558565
ici_context_parallelism=2 \
559566
flow_shift=5.0 \
@@ -790,3 +797,6 @@ This script will automatically format your code with `pyink` and help you identi
790797
791798
792799
The full suite of -end-to end tests is in `tests` and `src/maxdiffusion/tests`. We run them with a nightly cadance.
800+
801+
## Profiling
802+
To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md).

docs/profiling.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# ML Diagnostics and Profiling
2+
3+
MaxDiffusion supports automated profiling and performance tracking via [Google Cloud ML Diagnostics](https://docs.cloud.google.com/tpu/docs/ml-diagnostics/sdk).
4+
5+
## 1. Manual Installation
6+
To keep the core MaxDiffusion repository lightweight and ensure it runs without dependencies for users who don't need profiling, the ML Diagnostics packages are **not** installed by default.
7+
8+
To use this feature, you must manually install the required package in your environment:
9+
```bash
10+
pip install google-cloud-mldiagnostics
11+
```
12+
13+
## 2. Configuration Settings
14+
To enable ML Diagnostics for your training or generation jobs, you need to update your configuration. You can either add these directly to your .yml config file or pass them as command-line arguments:
15+
16+
```yaml
17+
# ML Diagnostics settings
18+
enable_ml_diagnostics: True
19+
profiler_gcs_path: "gs://<your-bucket-name>/profiler/ml_diagnostics"
20+
enable_ondemand_xprof: True
21+
```
22+
23+
## 3. GCS Bucket Permissions (Troubleshooting)
24+
The GCS bucket you provide in `profiler_gcs_path` **must** have the correct IAM permissions to allow the Hypercompute Cluster service account to write data.
25+
26+
If permissions are not configured correctly, your job will fail with an error similar to this:
27+
> `message: 'service-32478767326@gcp-sa-hypercomputecluster.iam.gserviceaccount.com does not have storage.buckets.get access to the GCS bucket <your-bucket>: permission denied'`
28+
29+
**Fix:** Ensure you grant the required Storage roles (e.g., `Storage Object Admin`) to the service account mentioned in your error message for your specific GCS bucket.
30+
31+
## 4. Viewing Your Runs
32+
Once your job is running with diagnostics enabled, you can monitor the profiles, execution times, and metrics in the Cluster Director console here:
33+
34+
🔗 **https://pantheon.corp.google.com/cluster-director/diagnostics**

src/maxdiffusion/configs/base14.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,8 @@ quantization: ''
247247
quantization_local_shard_count: -1
248248
use_qwix_quantization: False
249249
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
250+
251+
# ML Diagnostics settings
252+
enable_ml_diagnostics: False
253+
profiler_gcs_path: ""
254+
enable_ondemand_xprof: False

src/maxdiffusion/configs/base21.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,4 +247,9 @@ quantization: ''
247247
# Shard the range finding operation for quantization. By default this is set to number of slices.
248248
quantization_local_shard_count: -1
249249
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
250-
use_qwix_quantization: False
250+
use_qwix_quantization: False
251+
252+
# ML Diagnostics settings
253+
enable_ml_diagnostics: False
254+
profiler_gcs_path: ""
255+
enable_ondemand_xprof: False

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,8 @@ quantization: ''
263263
quantization_local_shard_count: -1
264264
use_qwix_quantization: False
265265
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
266+
267+
# ML Diagnostics settings
268+
enable_ml_diagnostics: False
269+
profiler_gcs_path: ""
270+
enable_ondemand_xprof: False

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,3 +306,7 @@ quantization_local_shard_count: -1
306306
use_qwix_quantization: False
307307
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
308308

309+
# ML Diagnostics settings
310+
enable_ml_diagnostics: False
311+
profiler_gcs_path: ""
312+
enable_ondemand_xprof: False

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,7 @@ quantization_local_shard_count: -1
291291
use_qwix_quantization: False
292292
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
293293

294+
# ML Diagnostics settings
295+
enable_ml_diagnostics: False
296+
profiler_gcs_path: ""
297+
enable_ondemand_xprof: False

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,4 +300,9 @@ quantization_local_shard_count: -1
300300
use_qwix_quantization: False
301301
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
302302

303-
save_final_checkpoint: False
303+
save_final_checkpoint: False
304+
305+
# ML Diagnostics settings
306+
enable_ml_diagnostics: False
307+
profiler_gcs_path: ""
308+
enable_ondemand_xprof: False

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,4 +409,9 @@ eval_data_dir: ""
409409
enable_generate_video_for_eval: False # This will increase the used TPU memory.
410410
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).
411411

412-
enable_ssim: False
412+
enable_ssim: False
413+
414+
# ML Diagnostics settings
415+
enable_ml_diagnostics: False
416+
profiler_gcs_path: ""
417+
enable_ondemand_xprof: False

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,8 @@ enable_generate_video_for_eval: False # This will increase the used TPU memory.
350350
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).
351351

352352
enable_ssim: False
353+
354+
# ML Diagnostics settings
355+
enable_ml_diagnostics: False
356+
profiler_gcs_path: ""
357+
enable_ondemand_xprof: False

0 commit comments

Comments
 (0)