Skip to content

Commit a1ad421

Browse files
authored
Update pyconfig.py
1 parent aa7befd commit a1ad421

1 file changed

Lines changed: 17 additions & 22 deletions

File tree

src/maxdiffusion/pyconfig.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import yaml
2626
from . import max_logging
2727
from . import max_utils
28-
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
2928

3029

3130
def string_to_bool(s: str) -> bool:
@@ -42,6 +41,21 @@ def string_to_bool(s: str) -> bool:
4241
config = None
4342

4443

44+
def create_parallelisms_list(raw_keys):
45+
ici_parallelism = [
46+
raw_keys["ici_data_parallelism"],
47+
raw_keys["ici_fsdp_parallelism"],
48+
raw_keys["ici_fsdp_transpose_parallelism"],
49+
raw_keys["ici_sequence_parallelism"],
50+
raw_keys["ici_tensor_parallelism"],
51+
raw_keys["ici_tensor_transpose_parallelism"],
52+
raw_keys["ici_expert_parallelism"],
53+
raw_keys["ici_sequence_parallelism"],
54+
]
55+
raw_keys["ici_parallelism"] = ici_parallelism
56+
return raw_keys
57+
58+
4559
def print_system_information():
4660
max_logging.log(f"System Information: Jax Version: {jax.__version__}")
4761
max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}")
@@ -103,7 +117,6 @@ def __init__(self, argv: list[str], **kwargs):
103117
jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"])
104118

105119
_HyperParameters.user_init(raw_keys)
106-
_HyperParameters.wan_init(raw_keys)
107120
self.keys = raw_keys
108121
for k in sorted(raw_keys.keys()):
109122
max_logging.log(f"Config param {k}: {raw_keys[k]}")
@@ -112,26 +125,6 @@ def _load_kwargs(self, argv: list[str]):
112125
args_dict = dict(a.split("=", 1) for a in argv[2:])
113126
return args_dict
114127

115-
@staticmethod
116-
def wan_init(raw_keys):
117-
if "wan_transformer_pretrained_model_name_or_path" in raw_keys:
118-
transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
119-
if transformer_pretrained_model_name_or_path == "":
120-
raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
121-
elif (
122-
transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
123-
or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH
124-
):
125-
# Set correct parameters for CausVid in case of user error.
126-
raw_keys["guidance_scale"] = 1.0
127-
num_inference_steps = raw_keys["num_inference_steps"]
128-
if num_inference_steps > 10:
129-
max_logging.log(
130-
f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps."
131-
)
132-
else:
133-
raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
134-
135128
@staticmethod
136129
def user_init(raw_keys):
137130
"""Transformations between the config data and configs used at runtime"""
@@ -176,6 +169,8 @@ def user_init(raw_keys):
176169
raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"])
177170
raw_keys["num_slices"] = get_num_slices(raw_keys)
178171
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
172+
if "ici_fsdp_transpose_parallelism" in raw_keys:
173+
raw_keys = create_parallelisms_list(raw_keys)
179174

180175

181176
def get_num_slices(raw_keys):

0 commit comments

Comments
 (0)