From 86daf3267a010fc3cd0c035d1a2f6573bd381683 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 20 Oct 2025 21:53:42 +0000 Subject: [PATCH 1/4] adds wan2.1 training readme guide. --- README.md | 256 +++++++++++++++++++++- src/maxdiffusion/configs/base_wan_14b.yml | 4 +- 2 files changed, 257 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index e6ed5f6ea..ad25415fa 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,261 @@ After installation completes, run the training script. ## Wan 2.1 Training - Coming soon. + in the first part, we'll run on a single host VM to get familiar with the workflow, then run on xpk for large scale training. + + Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). + + This workflow was tested using v5p-8 with a 500GB disk attached. + + ### Dataset Preparation + + For this example, we'll be using the [PusaV1 dataset](https://huggingface.co/datasets/RaphaelLiu/PusaV1_training). + + First, download the dataset. + + ```bash + export HF_DATASET_DIR=/mnt/disks/external_disk/PusaV1_training/ + export TFRECORDS_DATASET_DIR=/mnt/disks/external_disk/wan_tfr_dataset_pusa_v1 + huggingface-cli download RaphaelLiu/PusaV1_training --repo-type dataset --local-dir $HF_DATASET_DIR + ``` + + Next run the TFRecords conversion script. This step prepares training an eval datasets. Validation is done as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206). More details [here](https://github.com/mlcommons/training/tree/master/text_to_image#5-quality) + + Training dataset. 
+ + ```bash + python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py src/maxdiffusion/configs/base_wan_14b.yml train_data_dir=$HF_DATASET_DIR tfrecords_dir=$TFRECORDS_DATASET_DIR/train no_records_per_shard=10 enable_eval_timesteps=False + ``` + + The script will not have an output, but you can check the progress using: + + ```bash + ls -ll $TFRECORDS_DATASET_DIR/train + ``` + + Evaluation dataset. + + ```bash + python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py src/maxdiffusion/configs/base_wan_14b.yml train_data_dir=$HF_DATASET_DIR tfrecords_dir=$TFRECORDS_DATASET_DIR/eval no_records_per_shard=10 enable_eval_timesteps=True + ``` + + The evaluation dataset creation takes the first 420 samples of the dataset and adds a timestep field. We then need to manually delete the first 420 samples from the `train` folder so they are not used in training. + + + ```bash + printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' rm + ``` + + And verify that they do not exist. + + ```bash + printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' echo + ``` + + After the script is done running, you should see the following directory structure inside `$TFRECORDS_DATASET_DIR` + + ``` + train + eval_timesteps + ``` + + In some instances an empty file `file_42-430.tfrec` is created inside `eval_timesteps`, for sanity check, let's run a delete command. + + ```bash + rm $TFRECORDS_DATASET_DIR/eval_timesteps/file_42-430.tfrec + ``` + + ### Training on a Single VM + + Loading the data is supported both locally from the disk created above, or from `gcs`. In this guide, we'll be using a gcs bucket to train. First copy the data to the GCS bucket. 
+ + ```bash + BUCKET_NAME=my-bucket + gsutil -m cp -r $TFRECORDS_DATASET_DIR gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/} + ``` + + Now run the training command: + + ```bash + RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM} + OUTPUT_DIR=gs://$BUCKET_NAME/wan/ + DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/ + EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/ + SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/ + ``` + + ```bash + export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \ + --xla_tpu_megacore_fusion_allow_ags=false \ + --xla_enable_async_collective_permute=true \ + --xla_tpu_enable_ag_backward_pipelining=true \ + --xla_tpu_enable_data_parallel_all_reduce_opt=true \ + --xla_tpu_data_parallel_opt_different_sized_ops=true \ + --xla_tpu_enable_async_collective_fusion=true \ + --xla_tpu_enable_async_collective_fusion_multiple_steps=true \ + --xla_tpu_overlap_compute_collective_tc=true \ + --xla_enable_async_all_gather=true \ + --xla_tpu_scoped_vmem_limit_kib=81920 \ + --xla_tpu_enable_async_all_to_all=true \ + --xla_tpu_enable_all_experimental_scheduler_features=true \ + --xla_tpu_enable_scheduler_memory_pressure_tracking=true \ + --xla_tpu_host_transfer_overlap_limit=24 \ + --xla_tpu_aggressive_opt_barrier_removal=ENABLED \ + --xla_lhs_prioritize_async_depth_over_stall=ENABLED \ + --xla_should_allow_loop_variant_parameter_in_chain=ENABLED \ + --xla_should_add_loop_invariant_op_in_chain=ENABLED \ + --xla_max_concurrent_host_send_recv=100 \ + --xla_tpu_scheduler_percent_shared_memory_limit=100 \ + --xla_latency_hiding_scheduler_rerun=2 \ + --xla_tpu_use_minor_sharding_for_major_trivial_input=true \ + --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \ + --xla_tpu_assign_all_reduce_scatter_layout=true' + ``` + + ```bash + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \ + 
src/maxdiffusion/configs/base_wan_14b.yml \ + attention='flash' \ + weights_dtype=bfloat16 \ + activations_dtype=bfloat16 \ + guidance_scale=5.0 \ + flow_shift=5.0 \ + fps=16 \ + skip_jax_distributed_system=False \ + run_name=${RUN_NAME} \ + output_dir=${OUTPUT_DIR} \ + train_data_dir=${DATASET_DIR} \ + load_tfrecord_cached=True \ + height=1280 \ + width=720 \ + num_frames=81 \ + num_inference_steps=50 \ + jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \ + max_train_steps=1000 \ + enable_profiler=True \ + dataset_save_location=${SAVE_DATASET_DIR} \ + remat_policy='FULL' \ + flash_min_seq_length=0 \ + seed=$RANDOM \ + skip_first_n_steps_for_profiler=3 \ + profiler_steps=3 \ + per_device_batch_size=0.25 \ + ici_data_parallelism=1 \ + ici_fsdp_parallelism=4 \ + ici_tensor_parallelism=1 + ``` + + It is important to note a couple of things: + - per_device_batch_size can be fractional, but must be a whole number when multiplied by the number of devices. In this example, 0.25 * 4 (devices) = effective global batch size = 1. + - The step time in v5p-8 with global batch size = 1 is large due to using `FULL` remat. On a larger number of chips we can run larger batch sizes, greatly increasing MFU, as we will see in the next section on deploying with xpk. + - To enable eval during training set `eval_every` to a value > 0. + - In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, the ici_tensor_parallelism axis is used for head parallelism. + - You can enable both, keeping in mind that Wan2.1 has 40 heads and 40 must be evenly divisible by ici_tensor_parallelism. + - For Sequence parallelism, the code pads the sequence length to evenly divide the sequence. Try out different ici_fsdp_parallelism numbers, but we find 2 and 4 to be the best right now. + + You should eventually see a training run as: + + ```bash + ***** Running training ***** + Instantaneous batch size per device = 0.25 + Total train batch size (w.
parallel & distributed) = 1 + Total optimization steps = 1000 + Calculated TFLOPs per pass: 4893.2719 + Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4 + Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4 + Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4 + Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4 + completed step: 0, seconds: 142.395, TFLOP/s/device: 34.364, loss: 0.270 + To see full metrics 'tensorboard --logdir=gs://jfacevedo-maxdiffusion-v5p/wan/jfacevedo-wan-v5p-8-17263/tensorboard/' + completed step: 1, seconds: 137.207, TFLOP/s/device: 35.664, loss: 0.144 + completed step: 2, seconds: 36.014, TFLOP/s/device: 135.871, loss: 0.210 + completed step: 3, seconds: 36.016, TFLOP/s/device: 135.864, loss: 0.120 + completed step: 4, seconds: 36.008, TFLOP/s/device: 135.894, loss: 0.107 + completed step: 5, seconds: 36.008, TFLOP/s/device: 135.895, loss: 0.346 + completed step: 6, seconds: 36.006, TFLOP/s/device: 135.900, loss: 0.169 + ``` + + ### Deploying with XPK + + This assumes the user has already created an xpk cluster, installed all dependencies, and created the dataset from the steps above. For getting started with MaxDiffusion and xpk see [this guide](docs/getting_started/run_maxdiffusion_via_xpk.md).
+ + Using v5p-256, the command to run on xpk is as follows: + + ```bash + RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM} + OUTPUT_DIR=gs://$BUCKET_NAME/wan/ + DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/ + EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/ + SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/ + ``` + + ```bash + LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \ + --xla_tpu_megacore_fusion_allow_ags=false \ + --xla_enable_async_collective_permute=true \ + --xla_tpu_enable_ag_backward_pipelining=true \ + --xla_tpu_enable_data_parallel_all_reduce_opt=true \ + --xla_tpu_data_parallel_opt_different_sized_ops=true \ + --xla_tpu_enable_async_collective_fusion=true \ + --xla_tpu_enable_async_collective_fusion_multiple_steps=true \ + --xla_tpu_overlap_compute_collective_tc=true \ + --xla_enable_async_all_gather=true \ + --xla_tpu_scoped_vmem_limit_kib=81920 \ + --xla_tpu_enable_async_all_to_all=true \ + --xla_tpu_enable_all_experimental_scheduler_features=true \ + --xla_tpu_enable_scheduler_memory_pressure_tracking=true \ + --xla_tpu_host_transfer_overlap_limit=24 \ + --xla_tpu_aggressive_opt_barrier_removal=ENABLED \ + --xla_lhs_prioritize_async_depth_over_stall=ENABLED \ + --xla_should_allow_loop_variant_parameter_in_chain=ENABLED \ + --xla_should_add_loop_invariant_op_in_chain=ENABLED \ + --xla_max_concurrent_host_send_recv=100 \ + --xla_tpu_scheduler_percent_shared_memory_limit=100 \ + --xla_latency_hiding_scheduler_rerun=2 \ + --xla_tpu_use_minor_sharding_for_major_trivial_input=true \ + --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \ + --xla_tpu_assign_all_reduce_scatter_layout=true' + ``` + + ```bash + python3 ~/xpk/xpk.py workload create \ + --cluster=$CLUSTER_NAME \ + --project=$PROJECT \ + --zone=$ZONE \ + --device-type=$DEVICE_TYPE \ + --num-slices=1 \ + --command=" \ + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python
src/maxdiffusion/train_wan.py \ + src/maxdiffusion/configs/base_wan_14b.yml \ + attention='flash' \ + weights_dtype=bfloat16 \ + activations_dtype=bfloat16 \ + guidance_scale=5.0 \ + flow_shift=5.0 \ + fps=16 \ + skip_jax_distributed_system=False \ + run_name=${RUN_NAME} \ + output_dir=${OUTPUT_DIR} \ + train_data_dir=${DATASET_DIR} \ + load_tfrecord_cached=True \ + height=1280 \ + width=720 \ + num_frames=81 \ + num_inference_steps=50 \ + jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \ + max_train_steps=1000 \ + enable_profiler=True \ + dataset_save_location=${SAVE_DATASET_DIR} \ + remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \ + flash_min_seq_length=0 \ + seed=$RANDOM \ + skip_first_n_steps_for_profiler=3 \ + profiler_steps=3 \ + per_device_batch_size=0.25 \ + ici_data_parallelism=64 \ + ici_fsdp_parallelism=2 \ + ici_tensor_parallelism=1" + ``` ## Flux Training diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 4a9730454..78a65377a 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -237,7 +237,7 @@ global_batch_size: 0 tfrecords_dir: '' no_records_per_shard: 0 enable_eval_timesteps: False -considered_timesteps_list: [125, 250, 375, 500, 625, 750, 875] +timesteps_list: [125, 250, 375, 500, 625, 750, 875] num_eval_samples: 420 warmup_steps_fraction: 0.1 @@ -321,6 +321,6 @@ qwix_module_path: ".*" eval_every: -1 eval_data_dir: "" enable_generate_video_for_eval: False # This will increase the used TPU memory. -eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(considered_timesteps_list). +eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list). 
enable_ssim: False \ No newline at end of file From a230776cf07ea11a4e32b0f575011937e71dd61a Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 20 Oct 2025 22:07:44 +0000 Subject: [PATCH 2/4] update xpk command. --- README.md | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ad25415fa..792bab689 100644 --- a/README.md +++ b/README.md @@ -342,7 +342,6 @@ After installation completes, run the training script. num_frames=81 \ num_inference_steps=50 \ jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \ - max_train_steps=1000 \ enable_profiler=True \ dataset_save_location=${SAVE_DATASET_DIR} \ remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \ @@ -353,7 +352,17 @@ After installation completes, run the training script. per_device_batch_size=0.25 \ ici_data_parallelism=64 \ ici_fsdp_parallelism=2 \ - ici_tensor_parallelism=1" + ici_tensor_parallelism=1 \ + max_train_steps=5000 \ + eval_every=100 \ + eval_data_dir=${EVAL_DATA_DIR} \ + enable_generate_video_for_eval=True \ + warmup_steps_fraction=0.025" \ + --base-docker-image=${IMAGE_DIR} \ + --enable-debug-logs \ + --workload=${RUN_NAME} \ + --priority=medium \ + --max-restarts=0 ``` ## Flux Training From ac519e67ddbbeb939fe558b5db56472a43d40645 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 20 Oct 2025 22:19:49 +0000 Subject: [PATCH 3/4] update xpk parallelism. --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 792bab689..d83bbb4cf 100644 --- a/README.md +++ b/README.md @@ -350,8 +350,8 @@ After installation completes, run the training script.
skip_first_n_steps_for_profiler=3 \ profiler_steps=3 \ per_device_batch_size=0.25 \ - ici_data_parallelism=64 \ - ici_fsdp_parallelism=2 \ + ici_data_parallelism=32 \ + ici_fsdp_parallelism=4 \ ici_tensor_parallelism=1 \ max_train_steps=5000 \ eval_every=100 \ From ba15d24e35c18cfff71f0195b4646fde8ae895f2 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 21 Oct 2025 00:25:52 +0000 Subject: [PATCH 4/4] resolve sanbao's comments. --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d83bbb4cf..7345e0e50 100644 --- a/README.md +++ b/README.md @@ -118,7 +118,7 @@ After installation completes, run the training script. huggingface-cli download RaphaelLiu/PusaV1_training --repo-type dataset --local-dir $HF_DATASET_DIR ``` - Next run the TFRecords conversion script. This step prepares training an eval datasets. Validation is done as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206). More details [here](https://github.com/mlcommons/training/tree/master/text_to_image#5-quality) + Next run the TFRecords conversion script. This step prepares training and eval datasets. Validation is done as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206). More details [here](https://github.com/mlcommons/training/tree/master/text_to_image#5-quality) Training dataset. @@ -194,7 +194,7 @@ After installation completes, run the training script.
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \ --xla_tpu_overlap_compute_collective_tc=true \ --xla_enable_async_all_gather=true \ - --xla_tpu_scoped_vmem_limit_kib=81920 \ + --xla_tpu_scoped_vmem_limit_kib=65536 \ --xla_tpu_enable_async_all_to_all=true \ --xla_tpu_enable_all_experimental_scheduler_features=true \ --xla_tpu_enable_scheduler_memory_pressure_tracking=true \ @@ -299,7 +299,7 @@ After installation completes, run the training script. --xla_tpu_enable_async_collective_fusion_multiple_steps=true \ --xla_tpu_overlap_compute_collective_tc=true \ --xla_enable_async_all_gather=true \ - --xla_tpu_scoped_vmem_limit_kib=81920 \ + --xla_tpu_scoped_vmem_limit_kib=65536 \ --xla_tpu_enable_async_all_to_all=true \ --xla_tpu_enable_all_experimental_scheduler_features=true \ --xla_tpu_enable_scheduler_memory_pressure_tracking=true \