#!/bin/bash
#
# Launches a WAN 14B training workload on a TPU cluster via xpk.
#
# Build and push the base image first:
#   bash docker_build_dependency_image.sh
#   docker tag maxdiffusion_base_image:latest gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest
#   docker push gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest

set -euo pipefail

CLUSTER_NAME=bodaborg-tpu7x-128
DEVICE_TYPE=tpu7x-128 # can change to any size <= tpu7x-256
PROJECT=cloud-tpu-multipod-dev
ZONE=us-central1

# Please change RUN_NAME and the GCS bucket paths below to your own.
export RUN_NAME=wan-v7x-128-incre-bring-back-v0

USR_NAME=elisatsai
# NOTE: no space before the suffix. "gs://${USR_NAME} -wan-maxdiffusion"
# would be parsed as a temp-env assignment followed by a bogus command,
# leaving YOUR_GCS_BUCKET unset and corrupting every derived path.
YOUR_GCS_BUCKET=gs://${USR_NAME}-wan-maxdiffusion

OUTPUT_DIR=${YOUR_GCS_BUCKET}/wan/${RUN_NAME}

# Using sanbao's preprocessed dataset directories.
DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/train/
EVAL_DATA_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/eval_timesteps/
SAVE_DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/save/

# NOTE(review): assigning to RANDOM *seeds* bash's RNG, so the later
# seed=$RANDOM yields a reproducible pseudo-random value — NOT the literal
# 123456789. If the literal seed is intended, use a differently named var.
RANDOM=123456789
IMAGE_DIR=gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest
LIBTPU_VERSION=libtpu-0.0.25.dev20251013+tpu7x-cp312-cp312-manylinux_2_31_x86_64.whl
# Submit the training workload.
#
# The --command string is double-quoted, so every ${VAR} below expands on the
# LOCAL machine at submit time, and the backslash-newlines inside the quotes
# collapse the whole thing into a single remote command line — which is why
# no '#' comments may appear inside the quoted command.
#
# HF_TOKEN is taken from the local environment at submit time (${HF_TOKEN:?}
# aborts with a message if it is unset) instead of a hard-coded placeholder,
# so no token is committed to the script.
xpk workload create \
--cluster=$CLUSTER_NAME \
--project=$PROJECT \
--zone=$ZONE \
--device-type=$DEVICE_TYPE \
--num-slices=1 \
--command=" \
pip install . && \
gsutil cp gs://libtpu-tpu7x-releases/wheels/libtpu/${LIBTPU_VERSION} . && \
python -m pip install ${LIBTPU_VERSION} && \
pip install tokamax && \
export XLA_FLAGS='--xla_dump_to=/tmp/xla_dumps --xla_dump_hlo_as_text --xla_dump_hlo_as_proto' && \
export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_enable_async_all_reduce=true \
--xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \
--xla_max_concurrent_async_all_gathers=4 \
--xla_tpu_enable_async_all_to_all=true \
--xla_latency_hiding_scheduler_rerun=5 \
--xla_tpu_rwb_fusion=false \
--xla_tpu_enable_sublane_major_scaling_bitcast_fusion=false \
--xla_tpu_impure_enable_packed_bf16_math_ops=false \
--xla_tpu_enable_sparse_core_reduce_scatter_v2=true \
--xla_tpu_enable_sparse_core_collective_offload_all_gather=true \
--xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true \
--xla_tpu_enable_all_gather_offload_tracing=true \
--xla_tpu_use_tc_device_shape_on_sc=true \
--xla_tpu_prefer_async_allgather_to_allreduce=true \
--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_tpu_custom_call_scoped_vmem_adjustments=true \
--xla_enable_transpose_trace=false' && \
export HF_TOKEN=${HF_TOKEN:?export HF_TOKEN locally before submitting} && \
echo 'Starting WAN training ...' && \
HF_HUB_CACHE=/dev/shm python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
weights_dtype=bfloat16 \
activations_dtype=bfloat16 \
guidance_scale=5.0 \
flow_shift=5.0 \
fps=16 \
skip_jax_distributed_system=False \
run_name='test-wan-training-new' \
output_dir=${OUTPUT_DIR} \
train_data_dir=${DATASET_DIR} \
load_tfrecord_cached=True \
height=1280 \
width=720 \
num_frames=81 \
num_inference_steps=50 \
prompt='a japanese pop star young woman with black hair is singing with a smile. She is inside a studio with dim lighting and musical instruments.' \
jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
enable_profiler=True \
dataset_save_location=${SAVE_DATASET_DIR} \
remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
flash_min_seq_length=0 \
seed=$RANDOM \
skip_first_n_steps_for_profiler=0 \
profiler_steps=2 \
per_device_batch_size=0.5 \
ici_data_parallelism=64 \
ici_fsdp_parallelism=2 \
ici_tensor_parallelism=1 \
allow_split_physical_axes=True \
max_train_steps=2 \
scan_layers=true \
flash_block_sizes='{\"block_q\":2048,\"block_kv_compute\":512,\"block_kv\":2048,\"block_q_dkv\":2048,\"block_kv_dkv\":2048,\"block_kv_dkv_compute\":512,\"use_fused_bwd_kernel\":true}' || (echo 'Training failed, uploading HLO dumps...'; sleep 5; gsutil -m cp -r /tmp/xla_dumps ${OUTPUT_DIR}/hlo_dumps/ 2>&1); \
" \
--base-docker-image=${IMAGE_DIR} \
--enable-debug-logs \
--workload=${RUN_NAME} \
--priority=medium \
--max-restarts=0