Skip to content

Commit b293f3c

Browse files
author
Juan Acevedo
committed
fix unit tests.
1 parent b9019f8 commit b293f3c

3 files changed

Lines changed: 16 additions & 15 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ log_period: 100
3030
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
3131

3232
# Overrides the transformer from pretrained_model_name_or_path
33-
transformer_pretrained_model_name_or_path: ''
33+
wan_transformer_pretrained_model_name_or_path: ''
3434

3535
unet_checkpoint: ''
3636
revision: ''

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
9595
# 4. Load pretrained weights and move them to device using the state shardings from (3) above.
9696
# This helps with loading sharded weights directly into the accelerators without first copying them
9797
# all to one device and then distributing them, thus using low HBM memory.
98-
params = load_wan_transformer(config.transformer_pretrained_model_name_or_path, params, "cpu")
98+
params = load_wan_transformer(config.wan_transformer_pretrained_model_name_or_path, params, "cpu")
9999
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
100100
for path, val in flax.traverse_util.flatten_dict(params).items():
101101
sharding = logical_state_sharding[path].value

src/maxdiffusion/pyconfig.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,19 +114,20 @@ def _load_kwargs(self, argv: list[str]):
114114

115115
@staticmethod
116116
def wan_init(raw_keys):
117-
transformer_pretrained_model_name_or_path = raw_keys["transformer_pretrained_model_name_or_path"]
118-
if transformer_pretrained_model_name_or_path == "":
119-
raw_keys["transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
120-
elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
121-
# Set correct parameters for CausVid in case of user error.
122-
raw_keys["guidance_scale"] = 1.0
123-
num_inference_steps = raw_keys["num_inference_steps"]
124-
if num_inference_steps > 10:
125-
max_logging.log(
126-
f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps."
127-
)
128-
else:
129-
raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
117+
if "wan_transformer_pretrained_model_name_or_path" in raw_keys:
118+
transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
119+
if transformer_pretrained_model_name_or_path == "":
120+
raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
121+
elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
122+
# Set correct parameters for CausVid in case of user error.
123+
raw_keys["guidance_scale"] = 1.0
124+
num_inference_steps = raw_keys["num_inference_steps"]
125+
if num_inference_steps > 10:
126+
max_logging.log(
127+
f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps."
128+
)
129+
else:
130+
raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
130131

131132
@staticmethod
132133
def user_init(raw_keys):

0 commit comments

Comments
 (0)