Skip to content

Commit 546ecab

Browse files
committed
ruff fixed
1 parent 35a3337 commit 546ecab

4 files changed

Lines changed: 27 additions & 45 deletions

File tree

src/maxdiffusion/generate_ltx_video.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,6 @@ def run_inference(
6464
segment_ids=segment_ids,
6565
encoder_attention_segment_ids=encoder_attention_segment_ids,
6666
)
67-
prof = profiler.Profiler(config)
68-
prof.activate(optional_postfix="transformer step")
69-
prof.deactivate()
70-
7167

7268
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
7369
noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep))
@@ -176,8 +172,8 @@ def run(config):
176172
in_shardings=(state_shardings,),
177173
out_shardings=None,
178174
)
179-
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
180-
noise_pred = p_run_inference(states).block_until_ready()
175+
176+
noise_pred = p_run_inference(states).block_until_ready()
181177
print(noise_pred) # (4, 256, 128)
182178

183179

src/maxdiffusion/max_utils.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -257,21 +257,6 @@ def create_device_mesh(config, devices=None, logging=True):
257257
if devices is None:
258258
devices = jax.devices()
259259
num_devices = len(devices)
260-
##special case for ltx-video
261-
if "fsdp_transpose" in config.mesh_axes:
262-
num_slices = 1
263-
# if config.inference_benchmark_test else config.num_slices
264-
num_devices_per_slice = num_devices // num_slices
265-
# Find possible unspecified parallelisms
266-
ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
267-
mesh = mesh_utils.create_device_mesh(
268-
ici_parallelism,
269-
devices,
270-
)
271-
max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
272-
273-
return mesh
274-
275260
try:
276261
num_slices = 1 + max([d.slice_index for d in devices])
277262
except:
@@ -417,11 +402,7 @@ def setup_initial_state(
417402
config.enable_single_replica_ckpt_restoring,
418403
)
419404
if state:
420-
###!Edited
421-
if checkpoint_item == " ":
422-
state = state
423-
else:
424-
state = state[checkpoint_item]
405+
state = state[checkpoint_item]
425406
if not state:
426407
max_logging.log(f"Could not find the item in orbax, creating state...")
427408
init_train_state_partial = functools.partial(
@@ -628,4 +609,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
628609
initialize_jax_for_gpu()
629610
max_logging.log("Jax distributed system initialized on GPU!")
630611
else:
631-
jax.distributed.initialize()
612+
jax.distributed.initialize()

src/maxdiffusion/pyconfig.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import yaml
2626
from . import max_logging
2727
from . import max_utils
28+
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
2829

2930

3031
def string_to_bool(s: str) -> bool:
@@ -41,21 +42,6 @@ def string_to_bool(s: str) -> bool:
4142
config = None
4243

4344

44-
def create_parallelisms_list(raw_keys):
45-
ici_parallelism = [
46-
raw_keys["ici_data_parallelism"],
47-
raw_keys["ici_fsdp_parallelism"],
48-
raw_keys["ici_fsdp_transpose_parallelism"],
49-
raw_keys["ici_sequence_parallelism"],
50-
raw_keys["ici_tensor_parallelism"],
51-
raw_keys["ici_tensor_transpose_parallelism"],
52-
raw_keys["ici_expert_parallelism"],
53-
raw_keys["ici_sequence_parallelism"],
54-
]
55-
raw_keys["ici_parallelism"] = ici_parallelism
56-
return raw_keys
57-
58-
5945
def print_system_information():
6046
max_logging.log(f"System Information: Jax Version: {jax.__version__}")
6147
max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}")
@@ -117,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs):
117103
jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"])
118104

119105
_HyperParameters.user_init(raw_keys)
106+
_HyperParameters.wan_init(raw_keys)
120107
self.keys = raw_keys
121108
for k in sorted(raw_keys.keys()):
122109
max_logging.log(f"Config param {k}: {raw_keys[k]}")
@@ -125,6 +112,26 @@ def _load_kwargs(self, argv: list[str]):
125112
args_dict = dict(a.split("=", 1) for a in argv[2:])
126113
return args_dict
127114

115+
@staticmethod
116+
def wan_init(raw_keys):
117+
if "wan_transformer_pretrained_model_name_or_path" in raw_keys:
118+
transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
119+
if transformer_pretrained_model_name_or_path == "":
120+
raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
121+
elif (
122+
transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
123+
or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH
124+
):
125+
# Set correct parameters for CausVid in case of user error.
126+
raw_keys["guidance_scale"] = 1.0
127+
num_inference_steps = raw_keys["num_inference_steps"]
128+
if num_inference_steps > 10:
129+
max_logging.log(
130+
f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps."
131+
)
132+
else:
133+
raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
134+
128135
@staticmethod
129136
def user_init(raw_keys):
130137
"""Transformations between the config data and configs used at runtime"""
@@ -169,8 +176,6 @@ def user_init(raw_keys):
169176
raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"])
170177
raw_keys["num_slices"] = get_num_slices(raw_keys)
171178
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
172-
if "ici_fsdp_transpose_parallelism" in raw_keys:
173-
raw_keys = create_parallelisms_list(raw_keys)
174179

175180

176181
def get_num_slices(raw_keys):

src/maxdiffusion/tests/ltx_transformer_step_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def test_one_step_transformer(self):
191191
in_shardings=(state_shardings,),
192192
out_shardings=None,
193193
)
194-
194+
195195
noise_pred = p_run_inference(states).block_until_ready()
196196
noise_pred = torch.from_numpy(np.array(noise_pred))
197197

0 commit comments

Comments (0)