2525import yaml
2626from . import max_logging
2727from . import max_utils
28+ from .models .wan .wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH , WAN_21_FUSION_X_MODEL_NAME_OR_PATH
2829
2930
3031def string_to_bool (s : str ) -> bool :
@@ -41,21 +42,6 @@ def string_to_bool(s: str) -> bool:
4142config = None
4243
4344
44- def create_parallelisms_list (raw_keys ):
45- ici_parallelism = [
46- raw_keys ["ici_data_parallelism" ],
47- raw_keys ["ici_fsdp_parallelism" ],
48- raw_keys ["ici_fsdp_transpose_parallelism" ],
49- raw_keys ["ici_sequence_parallelism" ],
50- raw_keys ["ici_tensor_parallelism" ],
51- raw_keys ["ici_tensor_transpose_parallelism" ],
52- raw_keys ["ici_expert_parallelism" ],
53- raw_keys ["ici_sequence_parallelism" ],
54- ]
55- raw_keys ["ici_parallelism" ] = ici_parallelism
56- return raw_keys
57-
58-
5945def print_system_information ():
6046 max_logging .log (f"System Information: Jax Version: { jax .__version__ } " )
6147 max_logging .log (f"System Information: Jaxlib Version: { jax .lib .__version__ } " )
@@ -117,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs):
117103 jax .config .update ("jax_compilation_cache_dir" , raw_keys ["jax_cache_dir" ])
118104
119105 _HyperParameters .user_init (raw_keys )
106+ _HyperParameters .wan_init (raw_keys )
120107 self .keys = raw_keys
121108 for k in sorted (raw_keys .keys ()):
122109 max_logging .log (f"Config param { k } : { raw_keys [k ]} " )
@@ -125,6 +112,26 @@ def _load_kwargs(self, argv: list[str]):
125112 args_dict = dict (a .split ("=" , 1 ) for a in argv [2 :])
126113 return args_dict
127114
115+ @staticmethod
116+ def wan_init (raw_keys ):
117+ if "wan_transformer_pretrained_model_name_or_path" in raw_keys :
118+ transformer_pretrained_model_name_or_path = raw_keys ["wan_transformer_pretrained_model_name_or_path" ]
119+ if transformer_pretrained_model_name_or_path == "" :
120+ raw_keys ["wan_transformer_pretrained_model_name_or_path" ] = raw_keys ["pretrained_model_name_or_path" ]
121+ elif (
122+ transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
123+ or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH
124+ ):
125+ # Set correct parameters for CausVid in case of user error.
126+ raw_keys ["guidance_scale" ] = 1.0
127+ num_inference_steps = raw_keys ["num_inference_steps" ]
128+ if num_inference_steps > 10 :
129+ max_logging .log (
130+ f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting { num_inference_steps } steps."
131+ )
132+ else :
133+ raise ValueError (f"{ transformer_pretrained_model_name_or_path } transformer model is not supported for Wan 2.1" )
134+
128135 @staticmethod
129136 def user_init (raw_keys ):
130137 """Transformations between the config data and configs used at runtime"""
@@ -169,8 +176,6 @@ def user_init(raw_keys):
169176 raw_keys ["total_train_batch_size" ] = max_utils .get_global_batch_size (raw_keys ["per_device_batch_size" ])
170177 raw_keys ["num_slices" ] = get_num_slices (raw_keys )
171178 raw_keys ["quantization_local_shard_count" ] = get_quantization_local_shard_count (raw_keys )
172- if "ici_fsdp_transpose_parallelism" in raw_keys :
173- raw_keys = create_parallelisms_list (raw_keys )
174179
175180
176181def get_num_slices (raw_keys ):
@@ -221,4 +226,4 @@ def initialize(argv, **kwargs):
221226if __name__ == "__main__" :
222227 initialize (sys .argv )
223228 print (config .steps )
224- r = range (config .steps )
229+ r = range (config .steps )
0 commit comments