
Commit 0f8483e

1 parent 8bf24a3 commit 0f8483e

4 files changed

Lines changed: 44 additions & 66 deletions

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 7 additions & 11 deletions
@@ -27,33 +27,29 @@ output_dir: 'ltx-video-output'
 save_config_to_gcs: False
 
 #parallelism
-mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']
+mesh_axes: ['data', 'fsdp', 'tensor']
 logical_axis_rules: [
   ['batch', 'data'],
+  ['activation_heads', 'fsdp'],
   ['activation_batch', ['data','fsdp']],
-  ['activation_heads', 'tensor'],
   ['activation_kv', 'tensor'],
   ['mlp','tensor'],
   ['embed','fsdp'],
   ['heads', 'tensor'],
+  ['norm', 'fsdp'],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
+  ['conv_in', 'fsdp']
 ]
-data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']]
+data_sharding: [['data', 'fsdp', 'tensor']]
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: -1
 dcn_tensor_parallelism: 1
-ici_data_parallelism: -1
-ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_data_parallelism: 1
+ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1
 
-ici_fsdp_transpose_parallelism: 1
-ici_sequence_parallelism: 1
-ici_tensor_transpose_parallelism: 1
-ici_expert_parallelism: 1
-ici_sequence_parallelism: 1
-

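For reference, a minimal sketch (not part of the commit) of the device mesh this simplified config describes: three ICI axes named 'data', 'fsdp' and 'tensor', with the -1 on ici_fsdp_parallelism filled in from the local device count.

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

devices = jax.devices()
# ici_data_parallelism=1, ici_fsdp_parallelism=-1 (auto), ici_tensor_parallelism=1
mesh_shape = (1, len(devices), 1)
device_mesh = mesh_utils.create_device_mesh(mesh_shape, devices)
mesh = Mesh(device_mesh, axis_names=("data", "fsdp", "tensor"))
print(mesh.shape)  # axis sizes, e.g. data=1, fsdp=8, tensor=1 on an 8-device host
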
src/maxdiffusion/max_utils.py

Lines changed: 2 additions & 21 deletions
@@ -257,21 +257,6 @@ def create_device_mesh(config, devices=None, logging=True):
   if devices is None:
     devices = jax.devices()
   num_devices = len(devices)
-  ##special case for ltx-video
-  if "fsdp_transpose" in config.mesh_axes:
-    num_slices = 1
-    # if config.inference_benchmark_test else config.num_slices
-    num_devices_per_slice = num_devices // num_slices
-    # Find possible unspecified parallelisms
-    ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
-    mesh = mesh_utils.create_device_mesh(
-        ici_parallelism,
-        devices,
-    )
-    max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
-
-    return mesh
-
   try:
     num_slices = 1 + max([d.slice_index for d in devices])
   except:

@@ -417,11 +402,7 @@ def setup_initial_state(
       config.enable_single_replica_ckpt_restoring,
   )
   if state:
-    ###!Edited
-    if checkpoint_item == " ":
-      state = state
-    else:
-      state = state[checkpoint_item]
+    state = state[checkpoint_item]
   if not state:
     max_logging.log(f"Could not find the item in orbax, creating state...")
     init_train_state_partial = functools.partial(

@@ -628,4 +609,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
     initialize_jax_for_gpu()
     max_logging.log("Jax distributed system initialized on GPU!")
   else:
-    jax.distributed.initialize()
+    jax.distributed.initialize()

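The deleted ltx-video special case resolved the -1 entries with fill_unspecified_mesh_axes before building its own mesh; with the config reduced to three axes, the generic path below it is presumably sufficient. For illustration, a hypothetical re-implementation of that fill step (not the library function):

import math

def fill_unspecified(parallelism, num_devices):
  # Replace the single -1 entry so that the product of all axes equals num_devices.
  specified = math.prod(p for p in parallelism if p != -1)
  return [num_devices // specified if p == -1 else p for p in parallelism]

print(fill_unspecified([1, -1, 1], 8))  # -> [1, 8, 1] for ('data', 'fsdp', 'tensor')
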
src/maxdiffusion/models/ltx_video/transformers/attention.py

Lines changed: 13 additions & 17 deletions
@@ -18,6 +18,7 @@
 import math
 from typing import Any, Dict, Optional, Tuple
 from enum import Enum, auto
+
 import jax
 import jax.nn as jnn
 import jax.numpy as jnp

@@ -213,7 +214,8 @@ def __call__(
 
     # Adaptive Norm
     if self.adaptive_norm in ["single_scale_shift", "single_scale"]:
-      assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim]
+      # [batch, 1 or num_tokens, embedding_dim]
+      assert timestep.ndim == 3
       num_ada_params = self.scale_shift_table.shape[0]
       ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape(
           batch_size, timestep.shape[1], num_ada_params, -1

@@ -452,7 +454,7 @@ def __call__(
       deterministic: bool = True,
       **cross_attention_kwargs,
   ) -> jnp.ndarray:
-    cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa: F821
+    cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa F821
     assert cross_attention_kwargs.get("scale", None) is None, "Not supported"
 
     input_axis_names = ("activation_batch", "activation_length", "activation_embed")

@@ -636,27 +638,20 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids):
       raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.")
     # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim")
     # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py.
-    qkvo_sharding_spec = jax.sharding.PartitionSpec(
-        ("data", "fsdp", "fsdp_transpose", "expert"),
-        ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
-        None,
-        None,
-    )
     # qkvo_sharding_spec = jax.sharding.PartitionSpec(
     # ("data", "fsdp", "fsdp_transpose", "expert"),
     # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
     # None,
     # None,
     # )
-    # qkvo_sharding_spec = jax.sharding.PartitionSpec(
-    # None,
-    # None,
-    # None,
-    # None,
-    # )
+    qkvo_sharding_spec = jax.sharding.PartitionSpec(
+        "data",
+        "fsdp",
+        None,
+        "tensor",
+    )
     # Based on: ("activation_kv_batch", "activation_length")
-    qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence")
-    # qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None)
+    qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None)
     wrapped_flash_attention = shard_map(
         partial_flash_attention,
         mesh=sharding_mesh,

@@ -841,7 +836,8 @@ def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic:
       inner_dim = dim * self.mult
       if inner_dim < 256:
         raise ValueError("inner_dim must be at least 256")
-      inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256
+      # round to nearest multiple of 256
+      inner_dim = round(inner_dim / 256) * 256
     else:
       inner_dim = self.inner_dim
 

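For context, a self-contained sketch (assuming the 3-axis mesh from ltx_video.yml and a dummy kernel standing in for partial_flash_attention) of how the new PartitionSpecs in this file feed shard_map:

import jax
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec

devices = jax.devices()
mesh = Mesh(mesh_utils.create_device_mesh((1, len(devices), 1), devices),
            axis_names=("data", "fsdp", "tensor"))

# [batch, heads, length, head_dim]: heads sharded over 'fsdp', head_dim over 'tensor'.
qkvo_spec = PartitionSpec("data", "fsdp", None, "tensor")
# [batch, length]: segment ids sharded over 'data' only.
seg_spec = PartitionSpec("data", None)

def dummy_attention(q, k, v, q_ids, kv_ids):
  return q  # each shard-local call sees only its slice of the arrays

wrapped = shard_map(
    dummy_attention,
    mesh=mesh,
    in_specs=(qkvo_spec, qkvo_spec, qkvo_spec, seg_spec, seg_spec),
    out_specs=qkvo_spec,
)
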
src/maxdiffusion/pyconfig.py

Lines changed: 22 additions & 17 deletions
@@ -25,6 +25,7 @@
 import yaml
 from . import max_logging
 from . import max_utils
+from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
 
 
 def string_to_bool(s: str) -> bool:

@@ -41,21 +42,6 @@ def string_to_bool(s: str) -> bool:
 
 config = None
 
-def create_parallelisms_list(raw_keys):
-  ici_parallelism = [
-      raw_keys["ici_data_parallelism"],
-      raw_keys["ici_fsdp_parallelism"],
-      raw_keys["ici_fsdp_transpose_parallelism"],
-      raw_keys["ici_sequence_parallelism"],
-      raw_keys["ici_tensor_parallelism"],
-      raw_keys["ici_tensor_transpose_parallelism"],
-      raw_keys["ici_expert_parallelism"],
-      raw_keys["ici_sequence_parallelism"],
-  ]
-  raw_keys["ici_parallelism"] = ici_parallelism
-  return raw_keys
-
-
 def print_system_information():
   max_logging.log(f"System Information: Jax Version: {jax.__version__}")
   max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}")

@@ -117,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs):
       jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"])
 
     _HyperParameters.user_init(raw_keys)
+    _HyperParameters.wan_init(raw_keys)
     self.keys = raw_keys
     for k in sorted(raw_keys.keys()):
       max_logging.log(f"Config param {k}: {raw_keys[k]}")

@@ -125,6 +112,26 @@ def _load_kwargs(self, argv: list[str]):
     args_dict = dict(a.split("=", 1) for a in argv[2:])
     return args_dict
 
+  @staticmethod
+  def wan_init(raw_keys):
+    if "wan_transformer_pretrained_model_name_or_path" in raw_keys:
+      transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
+      if transformer_pretrained_model_name_or_path == "":
+        raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
+      elif (
+          transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
+          or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH
+      ):
+        # Set correct parameters for CausVid in case of user error.
+        raw_keys["guidance_scale"] = 1.0
+        num_inference_steps = raw_keys["num_inference_steps"]
+        if num_inference_steps > 10:
+          max_logging.log(
+              f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps."
+          )
+      else:
+        raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
+
   @staticmethod
   def user_init(raw_keys):
     """Transformations between the config data and configs used at runtime"""

@@ -169,8 +176,6 @@ def user_init(raw_keys):
     raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"])
     raw_keys["num_slices"] = get_num_slices(raw_keys)
     raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
-    if "ici_fsdp_transpose_parallelism" in raw_keys:
-      raw_keys = create_parallelisms_list(raw_keys)
 
 
 def get_num_slices(raw_keys):

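A hypothetical, self-contained illustration of what the new wan_init branch does to a parsed config dict (the real code is a @staticmethod on _HyperParameters and compares against the CausVid/FusionX constants imported from wan_utils; the keys mirror the diff, the checkpoint string below is a placeholder):

raw_keys = {
    "pretrained_model_name_or_path": "<base-wan-checkpoint>",
    "wan_transformer_pretrained_model_name_or_path": "",
    "guidance_scale": 5.0,
    "num_inference_steps": 30,
}

# An empty transformer path falls back to the base checkpoint; a CausVid or
# FusionX path would instead force guidance_scale to 1.0 and log a warning when
# more than 10 inference steps are configured; any other path raises ValueError.
if raw_keys["wan_transformer_pretrained_model_name_or_path"] == "":
  raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]

assert raw_keys["wan_transformer_pretrained_model_name_or_path"] == "<base-wan-checkpoint>"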