#!/usr/bin/env bash
# Launch configuration for a WAN-14B training test run on a TPU v5p slice.
set -euo pipefail

# Make the maxdiffusion sources importable. No spaces inside the quoted path:
# a leading/trailing space would become a bogus sys.path entry.
export PYTHONPATH="/home/sanbao_google_com/maxdiffusion/src:${PYTHONPATH:-}"

# Seed bash's RNG BEFORE any $RANDOM expansion so the run name and the
# training seed below are drawn from a fixed, reproducible sequence.
# (Originally this assignment came after RUN_NAME had already read $RANDOM.)
RANDOM=123456789

RUN_NAME=sanbao-v5p-test-${RANDOM}   # reproducible per-run identifier
OUTPUT_DIR=gs://sanbao-bucket/wan/sanbao-v5p-test
DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/train/
EVAL_DATA_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/eval_timesteps/
SAVE_DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/save/
CKPT_PATH=gs://sanbao-bucket/wan_ckp

# XLA / libtpu performance-tuning flags.
# Built from an array and joined with single spaces. The original embedded
# backslash-newline continuations INSIDE a single-quoted string; single
# quotes make backslashes literal, so stray '\' tokens (and a leading
# space) leaked into the exported flag list.
xla_flags=(
  --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
  --xla_tpu_megacore_fusion_allow_ags=false
  --xla_enable_async_collective_permute=true
  --xla_tpu_enable_ag_backward_pipelining=true
  --xla_tpu_enable_data_parallel_all_reduce_opt=true
  --xla_tpu_data_parallel_opt_different_sized_ops=true
  --xla_tpu_enable_async_collective_fusion=true
  --xla_tpu_enable_async_collective_fusion_multiple_steps=true
  --xla_tpu_overlap_compute_collective_tc=true
  --xla_enable_async_all_gather=true
  --xla_tpu_scoped_vmem_limit_kib=65536
  --xla_tpu_enable_async_all_to_all=true
  --xla_tpu_enable_all_experimental_scheduler_features=true
  --xla_tpu_enable_scheduler_memory_pressure_tracking=true
  --xla_tpu_host_transfer_overlap_limit=24
  --xla_tpu_aggressive_opt_barrier_removal=ENABLED
  --xla_lhs_prioritize_async_depth_over_stall=ENABLED
  --xla_should_allow_loop_variant_parameter_in_chain=ENABLED
  --xla_should_add_loop_invariant_op_in_chain=ENABLED
  --xla_max_concurrent_host_send_recv=100
  --xla_tpu_scheduler_percent_shared_memory_limit=100
  --xla_latency_hiding_scheduler_rerun=2
  --xla_tpu_use_minor_sharding_for_major_trivial_input=true
  --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1
  --xla_tpu_assign_all_reduce_scatter_layout=true
)
export LIBTPU_INIT_ARGS="${xla_flags[*]}"

# Launch the training job. HF_HUB_CACHE on /dev/shm keeps Hugging Face
# downloads in RAM for the lifetime of this process only.
# Fixes vs. original: leading spaces removed from attention= and
# remat_policy= values; the stray space in jax_cache_dir=${OUTPUT_DIR} /jax_cache/
# (which split it into two arguments) removed; variable expansions quoted.
HF_HUB_CACHE=/dev/shm/ python3 -m src.maxdiffusion.train_wan \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention='flash' \
  weights_dtype=bfloat16 \
  activations_dtype=bfloat16 \
  guidance_scale=5.0 \
  flow_shift=5.0 \
  fps=16 \
  skip_jax_distributed_system=False \
  run_name="${RUN_NAME}" \
  output_dir="${OUTPUT_DIR}" \
  train_data_dir="${DATASET_DIR}" \
  load_tfrecord_cached=True \
  height=1280 \
  width=720 \
  num_frames=81 \
  num_inference_steps=50 \
  jax_cache_dir="${OUTPUT_DIR}/jax_cache/" \
  max_train_steps=20 \
  enable_profiler=True \
  dataset_save_location="${SAVE_DATASET_DIR}" \
  remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
  flash_min_seq_length=0 \
  seed="$RANDOM" \
  skip_first_n_steps_for_profiler=3 \
  profiler_steps=3 \
  per_device_batch_size=0.5 \
  ici_data_parallelism=2 \
  ici_fsdp_parallelism=2 \
  ici_tensor_parallelism=1 \
  enable_ssim=False \
  flash_block_sizes='{"block_q":2048,"block_kv_compute":512,"block_kv":2048,"block_q_dkv":2048,"block_kv_dkv":2048,"block_kv_dkv_compute":512,"use_fused_bwd_kernel":true}'
  # NOTE(review): the original wrote \" inside single quotes, which passes
  # LITERAL backslashes to the config parser; plain JSON quoting is assumed
  # to be the intent — confirm against the train_wan flag parser.
# (removed: "0 commit comments" — GitHub page footer accidentally pasted in)