2525import yaml
2626from . import max_logging
2727from . import max_utils
28- from .models .wan .wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH , WAN_21_FUSION_X_MODEL_NAME_OR_PATH
2928
3029
3130def string_to_bool (s : str ) -> bool :
@@ -42,6 +41,21 @@ def string_to_bool(s: str) -> bool:
4241config = None
4342
4443
def create_parallelisms_list(raw_keys):
  """Collect the per-axis ICI parallelism settings into one list.

  Reads the individual ``ici_*_parallelism`` entries from ``raw_keys`` in a
  fixed order and stores the resulting list under
  ``raw_keys["ici_parallelism"]``.

  Args:
    raw_keys: Mutable mapping of config values. It must contain every key
      named in ``axis_keys`` below.

  Returns:
    The same ``raw_keys`` object, updated in place.

  Raises:
    KeyError: If any of the expected keys is missing from ``raw_keys``.
  """
  # The order of this tuple defines the order of the resulting list.
  axis_keys = (
      "ici_data_parallelism",
      "ici_fsdp_parallelism",
      "ici_fsdp_transpose_parallelism",
      "ici_sequence_parallelism",
      "ici_tensor_parallelism",
      "ici_tensor_transpose_parallelism",
      "ici_expert_parallelism",
      # NOTE(review): "ici_sequence_parallelism" is listed a second time here,
      # exactly as in the original. This looks like a copy-paste slip for a
      # different axis -- confirm against the consumer of "ici_parallelism"
      # before changing it, since the list length and order are preserved.
      "ici_sequence_parallelism",
  )
  raw_keys["ici_parallelism"] = [raw_keys[key] for key in axis_keys]
  return raw_keys
57+
58+
4559def print_system_information ():
4660 max_logging .log (f"System Information: Jax Version: { jax .__version__ } " )
4761 max_logging .log (f"System Information: Jaxlib Version: { jax .lib .__version__ } " )
@@ -103,7 +117,6 @@ def __init__(self, argv: list[str], **kwargs):
103117 jax .config .update ("jax_compilation_cache_dir" , raw_keys ["jax_cache_dir" ])
104118
105119 _HyperParameters .user_init (raw_keys )
106- _HyperParameters .wan_init (raw_keys )
107120 self .keys = raw_keys
108121 for k in sorted (raw_keys .keys ()):
109122 max_logging .log (f"Config param { k } : { raw_keys [k ]} " )
@@ -112,26 +125,6 @@ def _load_kwargs(self, argv: list[str]):
112125 args_dict = dict (a .split ("=" , 1 ) for a in argv [2 :])
113126 return args_dict
114127
115- @staticmethod
116- def wan_init (raw_keys ):
117- if "wan_transformer_pretrained_model_name_or_path" in raw_keys :
118- transformer_pretrained_model_name_or_path = raw_keys ["wan_transformer_pretrained_model_name_or_path" ]
119- if transformer_pretrained_model_name_or_path == "" :
120- raw_keys ["wan_transformer_pretrained_model_name_or_path" ] = raw_keys ["pretrained_model_name_or_path" ]
121- elif (
122- transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
123- or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH
124- ):
125- # Set correct parameters for CausVid in case of user error.
126- raw_keys ["guidance_scale" ] = 1.0
127- num_inference_steps = raw_keys ["num_inference_steps" ]
128- if num_inference_steps > 10 :
129- max_logging .log (
130- f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting { num_inference_steps } steps."
131- )
132- else :
133- raise ValueError (f"{ transformer_pretrained_model_name_or_path } transformer model is not supported for Wan 2.1" )
134-
135128 @staticmethod
136129 def user_init (raw_keys ):
137130 """Transformations between the config data and configs used at runtime"""
@@ -176,6 +169,8 @@ def user_init(raw_keys):
176169 raw_keys ["total_train_batch_size" ] = max_utils .get_global_batch_size (raw_keys ["per_device_batch_size" ])
177170 raw_keys ["num_slices" ] = get_num_slices (raw_keys )
178171 raw_keys ["quantization_local_shard_count" ] = get_quantization_local_shard_count (raw_keys )
172+ if "ici_fsdp_transpose_parallelism" in raw_keys :
173+ raw_keys = create_parallelisms_list (raw_keys )
179174
180175
181176def get_num_slices (raw_keys ):
0 commit comments