
Commit 35a3337

added test

2 parents: 8a043f6 + 18ec247

3 files changed: 7 additions & 36 deletions

File tree

src/maxdiffusion/configs/ltx_video.yml
src/maxdiffusion/models/ltx_video/transformers/attention.py
src/maxdiffusion/pyconfig.py

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 1 addition & 14 deletions
@@ -22,19 +22,6 @@ weights_dtype: 'bfloat16'
 activations_dtype: 'bfloat16'
 
 
-run_name: ''
-output_dir: 'ltx-video-output'
-save_config_to_gcs: False
-
-#hardware
-hardware: 'tpu'
-skip_jax_distributed_system: False
-
-jax_cache_dir: ''
-weights_dtype: 'bfloat16'
-activations_dtype: 'bfloat16'
-
-
 run_name: ''
 output_dir: 'ltx-video-output'
 save_config_to_gcs: False
@@ -78,4 +65,4 @@ per_device_batch_size: 1
 compile_topology_num_slices: -1
 quantization_local_shard_count: -1
 jit_initializers: True
-enable_single_replica_ckpt_restoring: False
+enable_single_replica_ckpt_restoring: False
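
This first hunk is the merge resolution itself: the deleted block repeats keys that the surrounding context shows elsewhere in the file (run_name, output_dir, save_config_to_gcs, weights_dtype, and activations_dtype are all visible twice in this hunk alone). Duplicated YAML keys are worth removing promptly because the loader resolves them silently. A minimal standalone sketch of PyYAML's behavior, not code from this repo:

    import yaml

    # Two occurrences of the same key, mirroring the duplicated block that
    # this merge removes from ltx_video.yml.
    doc = "weights_dtype: 'bfloat16'\nweights_dtype: 'float32'\n"
    # PyYAML keeps the last occurrence and raises no warning, so a stale
    # duplicate silently overrides the value that appears first in the file.
    print(yaml.safe_load(doc))  # -> {'weights_dtype': 'float32'}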

src/maxdiffusion/models/ltx_video/transformers/attention.py

Lines changed: 6 additions & 0 deletions
@@ -622,6 +622,12 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids):
     raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.")
   # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim")
   # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py.
+  # qkvo_sharding_spec = jax.sharding.PartitionSpec(
+  #     ("data", "fsdp", "fsdp_transpose", "expert"),
+  #     ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
+  #     None,
+  #     None,
+  # )
   qkvo_sharding_spec = jax.sharding.PartitionSpec(
       "data",
       "fsdp",

src/maxdiffusion/pyconfig.py

Lines changed: 0 additions & 22 deletions
@@ -25,7 +25,6 @@
 import yaml
 from . import max_logging
 from . import max_utils
-from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
 
 
 def string_to_bool(s: str) -> bool:
@@ -118,7 +117,6 @@ def __init__(self, argv: list[str], **kwargs):
       jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"])
 
     _HyperParameters.user_init(raw_keys)
-    _HyperParameters.wan_init(raw_keys)
     self.keys = raw_keys
     for k in sorted(raw_keys.keys()):
       max_logging.log(f"Config param {k}: {raw_keys[k]}")
@@ -127,26 +125,6 @@ def _load_kwargs(self, argv: list[str]):
     args_dict = dict(a.split("=", 1) for a in argv[2:])
     return args_dict
 
-  @staticmethod
-  def wan_init(raw_keys):
-    if "wan_transformer_pretrained_model_name_or_path" in raw_keys:
-      transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
-      if transformer_pretrained_model_name_or_path == "":
-        raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
-      elif (
-          transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
-          or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH
-      ):
-        # Set correct parameters for CausVid in case of user error.
-        raw_keys["guidance_scale"] = 1.0
-        num_inference_steps = raw_keys["num_inference_steps"]
-        if num_inference_steps > 10:
-          max_logging.log(
-              f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps."
-          )
-      else:
-        raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
-
   @staticmethod
   def user_init(raw_keys):
     """Transformations between the config data and configs used at runtime"""

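The removed wan_init was a post-load fixup hook: it defaulted the Wan-specific checkpoint path from the generic pretrained_model_name_or_path and coerced guidance_scale and num_inference_steps for the CausVid and FusionX variants. For reference, a sketch of that pattern's shape, with hypothetical key names rather than this repo's real config schema:

    def fixup_config(raw_keys: dict) -> None:
      # Hypothetical keys, illustrating the removed wan_init pattern.
      # 1. Default a model-specific path from the generic one.
      if raw_keys.get("transformer_path", "") == "":
        raw_keys["transformer_path"] = raw_keys["pretrained_model_name_or_path"]
      # 2. Coerce settings that a known model variant requires.
      if raw_keys.get("variant") == "distilled":
        raw_keys["guidance_scale"] = 1.0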