Commit ca9a855

Merge branch 'main' into optimizer-resume
2 parents e8d8ccf + 662d501 commit ca9a855

8 files changed

Lines changed: 317 additions & 14 deletions

README.md

Lines changed: 263 additions & 1 deletion
@@ -100,7 +100,269 @@ After installation completes, run the training script.

## Wan 2.1 Training

-Coming soon.

In the first part, we'll run on a single-host VM to get familiar with the workflow, then run on XPK for large-scale training.

Although not required, attaching an external disk is recommended, as model weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).

This workflow was tested on a v5p-8 with a 500GB disk attached.

### Dataset Preparation

For this example, we'll be using the [PusaV1 dataset](https://huggingface.co/datasets/RaphaelLiu/PusaV1_training).

First, download the dataset:

```bash
export HF_DATASET_DIR=/mnt/disks/external_disk/PusaV1_training/
export TFRECORDS_DATASET_DIR=/mnt/disks/external_disk/wan_tfr_dataset_pusa_v1
huggingface-cli download RaphaelLiu/PusaV1_training --repo-type dataset --local-dir $HF_DATASET_DIR
```
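
If you prefer Python over the CLI, the same download can be done with `huggingface_hub`, the library behind `huggingface-cli` (a minimal sketch; paths match the env vars set above):

```python
from huggingface_hub import snapshot_download

# Mirrors the huggingface-cli call above.
snapshot_download(
    repo_id="RaphaelLiu/PusaV1_training",
    repo_type="dataset",
    local_dir="/mnt/disks/external_disk/PusaV1_training/",
)
```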

Next, run the TFRecords conversion script. This step prepares the training and eval datasets. Validation is done as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206); more details are available [here](https://github.com/mlcommons/training/tree/master/text_to_image#5-quality).

Training dataset:

```bash
python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py src/maxdiffusion/configs/base_wan_14b.yml train_data_dir=$HF_DATASET_DIR tfrecords_dir=$TFRECORDS_DATASET_DIR/train no_records_per_shard=10 enable_eval_timesteps=False
```

The script does not print output, but you can check its progress with:

```bash
ls -l $TFRECORDS_DATASET_DIR/train
```

Evaluation dataset:

```bash
python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py src/maxdiffusion/configs/base_wan_14b.yml train_data_dir=$HF_DATASET_DIR tfrecords_dir=$TFRECORDS_DATASET_DIR/eval no_records_per_shard=10 enable_eval_timesteps=True
```

The evaluation dataset creation takes the first 420 samples of the dataset and adds a timestep field. We then need to manually delete the first 420 samples from the `train` folder so they are not used in training:

```bash
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' rm
```

Then verify that they no longer exist:

```bash
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' echo
```
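
If the `awk` one-liner is opaque, here is an equivalent sketch in Python (it assumes, as the filter above does, that the number after the hyphen in `file_<shard>-<n>.tfrec` is the shard's last sample index):

```python
import glob
import os
import re

train_dir = os.path.join(os.environ["TFRECORDS_DATASET_DIR"], "train")

for path in sorted(glob.glob(os.path.join(train_dir, "file_*-*.tfrec"))):
    # Same test as the awk filter: drop shards whose trailing index is <= 420.
    match = re.search(r"-(\d+)\.tfrec$", path)
    if match and int(match.group(1)) <= 420:
        os.remove(path)
        print("removed", path)
```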

After the script is done running, you should see the following directory structure inside `$TFRECORDS_DATASET_DIR`:

```
train
eval_timesteps
```

In some instances an empty file `file_42-430.tfrec` is created inside `eval_timesteps`; as a sanity check, delete it:

```bash
rm $TFRECORDS_DATASET_DIR/eval_timesteps/file_42-430.tfrec
```
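
To catch any other empty shards, you can count the records in each file (a quick sketch; it assumes TensorFlow is available, which the TFRecords tooling already implies):

```python
import glob
import os

import tensorflow as tf

eval_dir = os.path.join(os.environ["TFRECORDS_DATASET_DIR"], "eval_timesteps")

for shard in sorted(glob.glob(os.path.join(eval_dir, "*.tfrec"))):
    # A healthy shard holds up to no_records_per_shard=10 records; zero means
    # the shard is empty and safe to delete.
    count = sum(1 for _ in tf.data.TFRecordDataset(shard))
    print(f"{os.path.basename(shard)}: {count} records")
```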

### Training on a Single VM

Data can be loaded either locally from the disk created above or from GCS. In this guide, we'll train from a GCS bucket. First, copy the data to the bucket:

```bash
BUCKET_NAME=my-bucket
gsutil -m cp -r $TFRECORDS_DATASET_DIR gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}
```
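
To confirm the upload, you can list the copied shards with the `google-cloud-storage` client already pinned in `requirements.txt` (a minimal sketch; the bucket name and prefix are the placeholders used above):

```python
from google.cloud import storage

client = storage.Client()
# Matches gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/ from the copy.
for blob in client.list_blobs("my-bucket", prefix="wan_tfr_dataset_pusa_v1/train/"):
    print(blob.name, blob.size)
```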

Now set the run variables:

```bash
RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/
SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/
```

Export the recommended XLA flags:

```bash
export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_megacore_fusion_allow_ags=false \
--xla_enable_async_collective_permute=true \
--xla_tpu_enable_ag_backward_pipelining=true \
--xla_tpu_enable_data_parallel_all_reduce_opt=true \
--xla_tpu_data_parallel_opt_different_sized_ops=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_async_all_to_all=true \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \
--xla_lhs_prioritize_async_depth_over_stall=ENABLED \
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \
--xla_should_add_loop_invariant_op_in_chain=ENABLED \
--xla_max_concurrent_host_send_recv=100 \
--xla_tpu_scheduler_percent_shared_memory_limit=100 \
--xla_latency_hiding_scheduler_rerun=2 \
--xla_tpu_use_minor_sharding_for_major_trivial_input=true \
--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \
--xla_tpu_assign_all_reduce_scatter_layout=true'
```

Then launch training:

```bash
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
weights_dtype=bfloat16 \
activations_dtype=bfloat16 \
guidance_scale=5.0 \
flow_shift=5.0 \
fps=16 \
skip_jax_distributed_system=False \
run_name=${RUN_NAME} \
output_dir=${OUTPUT_DIR} \
train_data_dir=${DATASET_DIR} \
load_tfrecord_cached=True \
height=1280 \
width=720 \
num_frames=81 \
num_inference_steps=50 \
jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
max_train_steps=1000 \
enable_profiler=True \
dataset_save_location=${SAVE_DATASET_DIR} \
remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
flash_min_seq_length=0 \
seed=$RANDOM \
skip_first_n_steps_for_profiler=3 \
profiler_steps=3 \
per_device_batch_size=0.25 \
ici_data_parallelism=1 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=1
```

A few things are important to note:
- `per_device_batch_size` can be fractional, but it must multiply out to a whole number across devices. In this example, 0.25 * 4 devices = an effective global batch size of 1.
- The step time on v5p-8 with global batch size = 1 is large due to using `FULL` remat. On a larger number of chips we can run larger batch sizes, greatly increasing MFU, as we will see in the next section on deploying with XPK.
- To enable eval during training, set `eval_every` to a value > 0.
- In Wan2.1, the `ici_fsdp_parallelism` axis is used for sequence parallelism and the `ici_tensor_parallelism` axis is used for head parallelism.
- You can enable both, keeping in mind that Wan2.1 has 40 attention heads and 40 must be evenly divisible by `ici_tensor_parallelism`.
- For sequence parallelism, the code pads the sequence length so it divides evenly across the axis. Try different `ici_fsdp_parallelism` values, but we find 2 and 4 to be the best right now. These sizing rules are sketched in code below.
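
A minimal sketch of those sizing rules in plain Python (not MaxDiffusion code; `seq_len` is an illustrative assumption):

```python
# Values mirror the v5p-8 run above; seq_len is a hypothetical example.
num_devices = 4
per_device_batch_size = 0.25
ici_fsdp_parallelism = 4
ici_tensor_parallelism = 1
num_heads = 40  # Wan2.1 attention heads
seq_len = 75_349

# per_device_batch_size may be fractional, but the global batch size
# (per_device_batch_size * num_devices) must be a whole number.
global_batch_size = per_device_batch_size * num_devices
assert global_batch_size == int(global_batch_size)  # 0.25 * 4 == 1.0

# Head parallelism must divide Wan2.1's 40 attention heads evenly.
assert num_heads % ici_tensor_parallelism == 0

# Sequence parallelism pads the sequence up to a multiple of the fsdp axis.
padded_seq_len = -(-seq_len // ici_fsdp_parallelism) * ici_fsdp_parallelism
print(int(global_batch_size), padded_seq_len)  # 1 75352
```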

You should eventually see training output like:

```bash
***** Running training *****
Instantaneous batch size per device = 0.25
Total train batch size (w. parallel & distributed) = 1
Total optimization steps = 1000
Calculated TFLOPs per pass: 4893.2719
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
completed step: 0, seconds: 142.395, TFLOP/s/device: 34.364, loss: 0.270
To see full metrics 'tensorboard --logdir=gs://jfacevedo-maxdiffusion-v5p/wan/jfacevedo-wan-v5p-8-17263/tensorboard/'
completed step: 1, seconds: 137.207, TFLOP/s/device: 35.664, loss: 0.144
completed step: 2, seconds: 36.014, TFLOP/s/device: 135.871, loss: 0.210
completed step: 3, seconds: 36.016, TFLOP/s/device: 135.864, loss: 0.120
completed step: 4, seconds: 36.008, TFLOP/s/device: 135.894, loss: 0.107
completed step: 5, seconds: 36.008, TFLOP/s/device: 135.895, loss: 0.346
completed step: 6, seconds: 36.006, TFLOP/s/device: 135.900, loss: 0.169
```
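
The first two steps are slow because they include XLA compilation; once the compiled program is cached (note `jax_cache_dir` above), step time settles at roughly 36 seconds.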

### Deploying with XPK

This assumes you have already created an XPK cluster, installed all dependencies, and created the dataset from the steps above. For getting started with MaxDiffusion and XPK, see [this guide](docs/getting_started/run_maxdiffusion_via_xpk.md).

Using a v5p-256, first set the run variables:

```bash
RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/
SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/
```

Export the same XLA flags:

```bash
LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_megacore_fusion_allow_ags=false \
--xla_enable_async_collective_permute=true \
--xla_tpu_enable_ag_backward_pipelining=true \
--xla_tpu_enable_data_parallel_all_reduce_opt=true \
--xla_tpu_data_parallel_opt_different_sized_ops=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_async_all_to_all=true \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \
--xla_lhs_prioritize_async_depth_over_stall=ENABLED \
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \
--xla_should_add_loop_invariant_op_in_chain=ENABLED \
--xla_max_concurrent_host_send_recv=100 \
--xla_tpu_scheduler_percent_shared_memory_limit=100 \
--xla_latency_hiding_scheduler_rerun=2 \
--xla_tpu_use_minor_sharding_for_major_trivial_input=true \
--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \
--xla_tpu_assign_all_reduce_scatter_layout=true'
```

Then create the workload:

```bash
python3 ~/xpk/xpk.py workload create \
--cluster=$CLUSTER_NAME \
--project=$PROJECT \
--zone=$ZONE \
--device-type=$DEVICE_TYPE \
--num-slices=1 \
--command=" \
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
weights_dtype=bfloat16 \
activations_dtype=bfloat16 \
guidance_scale=5.0 \
flow_shift=5.0 \
fps=16 \
skip_jax_distributed_system=False \
run_name=${RUN_NAME} \
output_dir=${OUTPUT_DIR} \
train_data_dir=${DATASET_DIR} \
load_tfrecord_cached=True \
height=1280 \
width=720 \
num_frames=81 \
num_inference_steps=50 \
jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
enable_profiler=True \
dataset_save_location=${SAVE_DATASET_DIR} \
remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
flash_min_seq_length=0 \
seed=$RANDOM \
skip_first_n_steps_for_profiler=3 \
profiler_steps=3 \
per_device_batch_size=0.25 \
ici_data_parallelism=32 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=1 \
max_train_steps=5000 \
eval_every=100 \
eval_data_dir=${EVAL_DATA_DIR} \
enable_generate_video_for_eval=True" \
--base-docker-image=${IMAGE_DIR} \
--enable-debug-logs \
--workload=${RUN_NAME} \
--priority=medium \
--max-restarts=0
```
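
Unlike the single-VM run, this workload enables evaluation during training (`eval_every=100`, with `eval_data_dir` pointing at the `eval_timesteps` dataset prepared earlier) and generates videos at eval time (`enable_generate_video_for_eval=True`).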
## Flux Training

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,11 +1,11 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
-jax>=0.6.2
+jax>=0.7.2
 jaxlib>=0.4.30
 grain
 google-cloud-storage>=2.17.0
 absl-py
 datasets
-flax>=0.11.0
+flax>=0.12.0
 optax>=0.2.3
 torch>=2.6.0
 torchvision>=0.20.1
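
After reinstalling dependencies, a quick way to confirm the new pins took effect (a minimal check):

```python
import flax
import jax

# The commit raises the minimum pins for both packages.
print("jax:", jax.__version__)    # expect >= 0.7.2
print("flax:", flax.__version__)  # expect >= 0.12.0
```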

requirements_with_jax_ai_image.txt

Lines changed: 2 additions & 2 deletions
@@ -1,13 +1,13 @@
 # Requirements for Building the MaxDiffusion Docker Image
 # These requirements are additional to the dependencies present in the JAX AI base image.
 --extra-index-url https://download.pytorch.org/whl/cpu
-jax>=0.6.2
+jax>=0.7.2
 jaxlib>=0.4.30
 grain
 google-cloud-storage>=2.17.0
 absl-py
 datasets
-flax>=0.10.2
+flax>=0.12.0
 optax>=0.2.3
 torch>=2.6.0
 torchvision>=0.20.1

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 4 additions & 3 deletions
@@ -81,7 +81,8 @@ flash_block_sizes: {
 # "block_kv_dkv" : 2048,
 # "block_kv_dkv_compute" : 2048,
 # "block_q_dq" : 3024,
-# "block_kv_dq" : 2048
+# "block_kv_dq" : 2048,
+# "use_fused_bwd_kernel": False,
 # }
 # GroupNorm groups
 norm_num_groups: 32
@@ -237,7 +238,7 @@ global_batch_size: 0
 tfrecords_dir: ''
 no_records_per_shard: 0
 enable_eval_timesteps: False
-considered_timesteps_list: [125, 250, 375, 500, 625, 750, 875]
+timesteps_list: [125, 250, 375, 500, 625, 750, 875]
 num_eval_samples: 420

 warmup_steps_fraction: 0.1
@@ -322,6 +323,6 @@ qwix_module_path: ".*"
 eval_every: -1
 eval_data_dir: ""
 enable_generate_video_for_eval: False # This will increase the used TPU memory.
-eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(considered_timesteps_list).
+eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).

 enable_ssim: False
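
The value in that comment follows from the other eval settings; a quick check with the values in this config:

```python
num_eval_samples = 420
timesteps_list = [125, 250, 375, 500, 625, 750, 875]

# eval_max_number_of_samples_in_bucket = num_eval_samples / len(timesteps_list)
assert num_eval_samples // len(timesteps_list) == 60
```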

src/maxdiffusion/max_utils.py

Lines changed: 8 additions & 2 deletions
@@ -489,6 +489,11 @@ def get_precision(config):
     retval = jax.lax.Precision.HIGHEST
   return retval
 
+def value_or_none(flash_block_sizes, key):
+  if key in flash_block_sizes:
+    return flash_block_sizes[key]
+  else:
+    return None
 
 def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
@@ -501,8 +506,9 @@ def get_flash_block_sizes(config):
       block_q_dkv=config.flash_block_sizes["block_q_dkv"],
       block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
       block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
-      block_q_dq=config.flash_block_sizes["block_q_dq"],
-      block_kv_dq=config.flash_block_sizes["block_kv_dq"],
+      block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"),
+      block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"),
+      use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel")
   )
   return flash_block_sizes

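The effect of the new helper is that `block_q_dq`, `block_kv_dq`, and `use_fused_bwd_kernel` become optional keys in `flash_block_sizes`. A small illustration of the fallback (the config dict below is hypothetical):

```python
def value_or_none(flash_block_sizes, key):
  if key in flash_block_sizes:
    return flash_block_sizes[key]
  else:
    return None

# Hypothetical config that omits the dq block sizes, e.g. when relying on
# the fused backward kernel.
flash_block_sizes = {
    "block_q": 3024,
    "block_kv": 2048,
    "use_fused_bwd_kernel": True,
}

assert value_or_none(flash_block_sizes, "block_q_dq") is None
assert value_or_none(flash_block_sizes, "use_fused_bwd_kernel") is True
```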