Skip to content

Commit 4587ff8

Browse files
lint/format files.
1 parent 59e1932 commit 4587ff8

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def rename_for_nnx(key):
2828
new_key = key[:-1] + ("scale",)
2929
return new_key
3030

31+
3132
def rename_for_custom_trasformer(key):
3233
renamed_pt_key = key.replace("model.diffusion_model.", "")
3334

@@ -53,6 +54,7 @@ def rename_for_custom_trasformer(key):
5354

5455
return renamed_pt_key
5556

57+
5658
def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
5759
device = jax.devices(device)[0]
5860
with jax.default_device(device):
@@ -74,7 +76,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
7476
random_flax_state_dict[string_tuple] = flattened_dict[key]
7577
for pt_key, tensor in tensors.items():
7678
renamed_pt_key = rename_key(pt_key)
77-
79+
7880
renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key)
7981

8082
pt_tuple_key = tuple(renamed_pt_key.split("."))
@@ -89,6 +91,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
8991
jax.clear_caches()
9092
return flax_state_dict
9193

94+
9295
def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
9396
device = jax.devices(device)[0]
9497
with jax.default_device(device):

src/maxdiffusion/pyconfig.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@ def wan_init(raw_keys):
118118
transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
119119
if transformer_pretrained_model_name_or_path == "":
120120
raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
121-
elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
121+
elif (
122+
transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
123+
or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH
124+
):
122125
# Set correct parameters for CausVid in case of user error.
123126
raw_keys["guidance_scale"] = 1.0
124127
num_inference_steps = raw_keys["num_inference_steps"]

0 commit comments

Comments
 (0)