
Commit dd36fd5

add sh

1 parent 0bbf7f0 commit dd36fd5

1 file changed: 68 additions & 0 deletions

run_wan_on_vm.sh
export PYTHONPATH="/home/sanbao_google_com/maxdiffusion/src:$PYTHONPATH"

RUN_NAME=sanbao-v5p-test-${RANDOM}
OUTPUT_DIR=gs://sanbao-bucket/wan/sanbao-v5p-test
DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/train/
EVAL_DATA_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/eval_timesteps/
SAVE_DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/save/
# Assigning to RANDOM seeds bash's generator, so seed=$RANDOM in the command
# below is reproducible across runs. RUN_NAME above already expanded an
# unseeded $RANDOM, so the run name itself stays random from run to run.
RANDOM=123456789
CKPT_PATH=gs://sanbao-bucket/wan_ckp
# (EVAL_DATA_DIR and CKPT_PATH are defined here but not referenced by the
# training command below.)
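# --- Not part of this commit: an optional pre-flight sketch that checks the
# --- buckets above are reachable. Assumes gsutil is installed and
# --- authenticated for the project that owns sanbao-bucket.
for d in "$DATASET_DIR" "$EVAL_DATA_DIR" "$SAVE_DATASET_DIR" "$CKPT_PATH"; do
  gsutil ls "$d" > /dev/null || echo "warning: cannot list $d"
done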
# XLA / libtpu performance flags: async collectives, collective fusion, and
# latency-hiding-scheduler tuning. The value is single-quoted across multiple
# lines, so the embedded newlines themselves separate the flags; backslash
# continuations inside single quotes would be passed to libtpu literally and
# are therefore dropped here.
export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
--xla_tpu_megacore_fusion_allow_ags=false
--xla_enable_async_collective_permute=true
--xla_tpu_enable_ag_backward_pipelining=true
--xla_tpu_enable_data_parallel_all_reduce_opt=true
--xla_tpu_data_parallel_opt_different_sized_ops=true
--xla_tpu_enable_async_collective_fusion=true
--xla_tpu_enable_async_collective_fusion_multiple_steps=true
--xla_tpu_overlap_compute_collective_tc=true
--xla_enable_async_all_gather=true
--xla_tpu_scoped_vmem_limit_kib=65536
--xla_tpu_enable_async_all_to_all=true
--xla_tpu_enable_all_experimental_scheduler_features=true
--xla_tpu_enable_scheduler_memory_pressure_tracking=true
--xla_tpu_host_transfer_overlap_limit=24
--xla_tpu_aggressive_opt_barrier_removal=ENABLED
--xla_lhs_prioritize_async_depth_over_stall=ENABLED
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED
--xla_should_add_loop_invariant_op_in_chain=ENABLED
--xla_max_concurrent_host_send_recv=100
--xla_tpu_scheduler_percent_shared_memory_limit=100
--xla_latency_hiding_scheduler_rerun=2
--xla_tpu_use_minor_sharding_for_major_trivial_input=true
--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1
--xla_tpu_assign_all_reduce_scatter_layout=true'
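# --- Not part of this commit: an optional sanity print of the flag list as
# --- libtpu will tokenize it, one flag per line. Relies on word-splitting of
# --- the unquoted expansion.
printf '%s\n' $LIBTPU_INIT_ARGS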
# Keep the Hugging Face hub cache in RAM for this run. Mesh: 2-way data
# parallel x 2-way FSDP x 1-way tensor = 4 chips; per_device_batch_size=0.5
# halves the global batch relative to one sample per chip.
# flash_block_sizes drops the backslash-escapes from the original: inside
# single quotes they would be passed through literally and break JSON parsing.
HF_HUB_CACHE=/dev/shm/ python3 -m src.maxdiffusion.train_wan \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention='flash' \
  weights_dtype=bfloat16 \
  activations_dtype=bfloat16 \
  guidance_scale=5.0 \
  flow_shift=5.0 \
  fps=16 \
  skip_jax_distributed_system=False \
  run_name=${RUN_NAME} \
  output_dir=${OUTPUT_DIR} \
  train_data_dir=${DATASET_DIR} \
  load_tfrecord_cached=True \
  height=1280 \
  width=720 \
  num_frames=81 \
  num_inference_steps=50 \
  jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
  max_train_steps=20 \
  enable_profiler=True \
  dataset_save_location=${SAVE_DATASET_DIR} \
  remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
  flash_min_seq_length=0 \
  seed=$RANDOM \
  skip_first_n_steps_for_profiler=3 \
  profiler_steps=3 \
  per_device_batch_size=0.5 \
  ici_data_parallelism=2 \
  ici_fsdp_parallelism=2 \
  ici_tensor_parallelism=1 \
  enable_ssim=False \
  flash_block_sizes='{"block_q":2048,"block_kv_compute":512,"block_kv":2048,"block_q_dkv":2048,"block_kv_dkv":2048,"block_kv_dkv_compute":512,"use_fused_bwd_kernel":true}'
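
For reference, a possible launch sketch — not part of this commit. TPU_NAME and ZONE are placeholders, and the script is assumed to be checked out at ~/maxdiffusion on every worker:

# Run the script on all workers of a TPU VM slice (assumed names/paths):
gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="$ZONE" --worker=all \
  --command="cd ~/maxdiffusion && bash run_wan_on_vm.sh"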
