Skip to content

Commit d06dee3

Browse files
committed
changed back pyconfig
1 parent e873a17 commit d06dee3

2 files changed

Lines changed: 24 additions & 95 deletions

File tree

src/maxdiffusion/max_utils.py

Lines changed: 1 addition & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -257,21 +257,6 @@ def create_device_mesh(config, devices=None, logging=True):
257257
if devices is None:
258258
devices = jax.devices()
259259
num_devices = len(devices)
260-
##special case for ltx-video
261-
if "fsdp_transpose" in config.mesh_axes:
262-
num_slices = 1
263-
# if config.inference_benchmark_test else config.num_slices
264-
num_devices_per_slice = num_devices // num_slices
265-
# Find possible unspecified parallelisms
266-
ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
267-
mesh = mesh_utils.create_device_mesh(
268-
ici_parallelism,
269-
devices,
270-
)
271-
max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
272-
273-
return mesh
274-
275260
try:
276261
num_slices = 1 + max([d.slice_index for d in devices])
277262
except:
@@ -303,66 +288,9 @@ def create_device_mesh(config, devices=None, logging=True):
303288
if logging:
304289
max_logging.log(f"Decided on mesh: {mesh}")
305290

306-
307-
308-
309-
310-
311-
312-
313-
314-
315-
316-
317-
318-
319-
320-
321-
322-
323-
324291
return mesh
325292

326293

327-
328-
329-
330-
331-
332-
333-
334-
335-
336-
337-
338-
339-
340-
341-
342-
343-
344-
345-
346-
347-
348-
349-
350-
351-
352-
353-
354-
355-
356-
357-
358-
359-
360-
361-
362-
363-
364-
365-
366294
def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
367295
"""Unboxes the flax.LogicallyPartitioned pieces in a train state.
368296
@@ -474,11 +402,7 @@ def setup_initial_state(
474402
config.enable_single_replica_ckpt_restoring,
475403
)
476404
if state:
477-
###!Edited
478-
if checkpoint_item == " ":
479-
state = state
480-
else:
481-
state = state[checkpoint_item]
405+
state = state[checkpoint_item]
482406
if not state:
483407
max_logging.log(f"Could not find the item in orbax, creating state...")
484408
init_train_state_partial = functools.partial(

src/maxdiffusion/pyconfig.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import yaml
2626
from . import max_logging
2727
from . import max_utils
28+
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
2829

2930

3031
def string_to_bool(s: str) -> bool:
@@ -41,21 +42,6 @@ def string_to_bool(s: str) -> bool:
4142
config = None
4243

4344

44-
def create_parallelisms_list(raw_keys):
45-
ici_parallelism = [
46-
raw_keys["ici_data_parallelism"],
47-
raw_keys["ici_fsdp_parallelism"],
48-
raw_keys["ici_fsdp_transpose_parallelism"],
49-
raw_keys["ici_sequence_parallelism"],
50-
raw_keys["ici_tensor_parallelism"],
51-
raw_keys["ici_tensor_transpose_parallelism"],
52-
raw_keys["ici_expert_parallelism"],
53-
raw_keys["ici_sequence_parallelism"],
54-
]
55-
raw_keys["ici_parallelism"] = ici_parallelism
56-
return raw_keys
57-
58-
5945
def print_system_information():
6046
max_logging.log(f"System Information: Jax Version: {jax.__version__}")
6147
max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}")
@@ -117,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs):
117103
jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"])
118104

119105
_HyperParameters.user_init(raw_keys)
106+
_HyperParameters.wan_init(raw_keys)
120107
self.keys = raw_keys
121108
for k in sorted(raw_keys.keys()):
122109
max_logging.log(f"Config param {k}: {raw_keys[k]}")
@@ -125,6 +112,26 @@ def _load_kwargs(self, argv: list[str]):
125112
args_dict = dict(a.split("=", 1) for a in argv[2:])
126113
return args_dict
127114

115+
@staticmethod
def wan_init(raw_keys):
  """Normalise the Wan transformer checkpoint setting in ``raw_keys``.

  ``raw_keys`` is modified in place; nothing is returned.

  * If ``wan_transformer_pretrained_model_name_or_path`` is absent, nothing
    happens.
  * If it is unset (``""`` or ``None``), it is replaced by
    ``raw_keys["pretrained_model_name_or_path"]``.
  * If it equals ``CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH`` or
    ``WAN_21_FUSION_X_MODEL_NAME_OR_PATH``, ``guidance_scale`` is forced to
    ``1.0`` and a warning is logged when ``num_inference_steps`` exceeds 10.

  Args:
    raw_keys: Mutable mapping of configuration keys to values.

  Raises:
    ValueError: If the transformer path is set to any other value.
  """
  key = "wan_transformer_pretrained_model_name_or_path"
  if key not in raw_keys:
    return

  transformer_path = raw_keys[key]

  # An empty YAML value loads as None rather than "", so treat both as
  # "not specified" and fall back to the base model path.
  if transformer_path is None or transformer_path == "":
    raw_keys[key] = raw_keys["pretrained_model_name_or_path"]
    return

  if transformer_path not in (
      CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH,
      WAN_21_FUSION_X_MODEL_NAME_OR_PATH,
  ):
    raise ValueError(f"{transformer_path} transformer model is not supported for Wan 2.1")

  # Override whatever guidance scale the user supplied for these checkpoints.
  raw_keys["guidance_scale"] = 1.0
  num_inference_steps = raw_keys["num_inference_steps"]
  # NOTE(review): the warning fires above 10 steps but recommends fewer than 8,
  # and it names CausVid even when the FusionX path matched - confirm the
  # intended threshold and wording before changing the message.
  if num_inference_steps > 10:
    max_logging.log(
        f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps."
    )
134+
128135
@staticmethod
129136
def user_init(raw_keys):
130137
"""Transformations between the config data and configs used at runtime"""
@@ -169,8 +176,6 @@ def user_init(raw_keys):
169176
raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"])
170177
raw_keys["num_slices"] = get_num_slices(raw_keys)
171178
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
172-
if "ici_fsdp_transpose_parallelism" in raw_keys:
173-
raw_keys = create_parallelisms_list(raw_keys)
174179

175180

176181
def get_num_slices(raw_keys):
@@ -221,4 +226,4 @@ def initialize(argv, **kwargs):
221226
if __name__ == "__main__":
  # Manual smoke check when the module is run as a script: build the config
  # from the command line and echo one value from it.
  # presumably initialize() populates the module-level `config` - verify against its definition.
  initialize(sys.argv)
  print(config.steps)

0 commit comments

Comments
 (0)