Skip to content

Commit 59e1932

Browse files
Add Wan2.1 FusionX transformer checkpoint loading support.
1 parent 31dcb6c commit 59e1932

2 files changed

Lines changed: 66 additions & 21 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
1010

1111
CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
12+
WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX"
1213

1314

1415
def _tuple_str_to_int(in_tuple):
@@ -27,6 +28,66 @@ def rename_for_nnx(key):
2728
new_key = key[:-1] + ("scale",)
2829
return new_key
2930

31+
# Ordered PyTorch-key -> Flax-key rewrite rules. Order matters: e.g.
# "head.modulation" must be handled before the generic ".modulation" rule,
# and "blocks_" must become "blocks." before the per-block renames apply.
_CUSTOM_TRANSFORMER_KEY_RULES = (
    ("model.diffusion_model.", ""),
    ("head.modulation", "scale_shift_table"),
    ("head.head", "proj_out"),
    ("text_embedding_0", "condition_embedder.text_embedder.linear_1"),
    ("text_embedding_2", "condition_embedder.text_embedder.linear_2"),
    ("time_embedding_0", "condition_embedder.time_embedder.linear_1"),
    ("time_embedding_2", "condition_embedder.time_embedder.linear_2"),
    ("time_projection_1", "condition_embedder.time_proj"),
    ("blocks_", "blocks."),
    ("self_attn", "attn1"),
    ("cross_attn", "attn2"),
    (".q.", ".query."),
    (".k.", ".key."),
    (".v.", ".value."),
    (".o.", ".proj_attn."),
    ("ffn_0", "ffn.act_fn.proj"),
    ("ffn_2", "ffn.proj_out"),
    (".modulation", ".scale_shift_table"),
    ("norm3", "norm2.layer_norm"),
)


def rename_for_custom_trasformer(key):
  """Map a PyTorch WAN-transformer checkpoint key to its Flax parameter name.

  Applies the substitution rules in _CUSTOM_TRANSFORMER_KEY_RULES in order,
  stripping the "model.diffusion_model." prefix and translating embedder,
  attention, feed-forward, and modulation sub-keys to the Flax naming scheme.
  """
  renamed = key
  for old, new in _CUSTOM_TRANSFORMER_KEY_RULES:
    renamed = renamed.replace(old, new)
  return renamed
55+
56+
def load_fusionx_transformer(
    pretrained_model_name_or_path: str,
    eval_shapes: dict,
    device: str,
    hf_download: bool = True,
    ckpt_filename: str = "Wan14BT2VFusioniX_fp16_.safetensors",
):
  """Load FusionX WAN 2.1 transformer weights from a single safetensors file.

  Downloads the checkpoint from the Hugging Face Hub, converts every tensor
  from PyTorch to JAX, renames the PyTorch keys to the Flax parameter layout,
  validates the result against ``eval_shapes``, and returns an unflattened
  Flax state dict with all weights placed on host CPU memory.

  Args:
    pretrained_model_name_or_path: Hugging Face Hub repo id holding the
      checkpoint (e.g. "vrgamedevgirl84/Wan14BT2VFusioniX").
    eval_shapes: flattened-able pytree of expected parameter shapes used to
      match, reshape, and validate the converted weights.
    device: platform name passed to ``jax.devices`` (e.g. "cpu", "tpu").
    hf_download: when True, fetch the checkpoint from the HF Hub. Must
      currently be True — there is no local-path fallback.
    ckpt_filename: name of the safetensors file inside the repo. Defaults to
      the single shard published in the FusionX repo.

  Returns:
    The unflattened Flax state dict of converted weights.

  Raises:
    ValueError: if ``hf_download`` is False. (Previously this fell through to
      an opaque ``NameError`` on the unbound checkpoint path.)
  """
  device = jax.devices(device)[0]
  with jax.default_device(device):
    if not hf_download:
      # Fail fast with a clear message instead of a NameError below.
      raise ValueError("load_fusionx_transformer currently requires hf_download=True.")
    ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename=ckpt_filename)
    tensors = {}
    with safe_open(ckpt_shard_path, framework="pt") as f:
      for k in f.keys():
        tensors[k] = torch2jax(f.get_tensor(k))

    flax_state_dict = {}
    cpu = jax.local_devices(backend="cpu")[0]
    flattened_dict = flatten_dict(eval_shapes)
    # Turn all block numbers into strings just for matching weights.
    # Later they will be turned back to ints by _tuple_str_to_int.
    random_flax_state_dict = {tuple(str(item) for item in key): value for key, value in flattened_dict.items()}
    for pt_key, tensor in tensors.items():
      renamed_pt_key = rename_key(pt_key)
      renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key)
      pt_tuple_key = tuple(renamed_pt_key.split("."))

      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
      flax_key = rename_for_nnx(flax_key)
      flax_key = _tuple_str_to_int(flax_key)
      # Keep converted weights on host memory; device placement happens later.
      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
    validate_flax_state_dict(eval_shapes, flax_state_dict)
    flax_state_dict = unflatten_dict(flax_state_dict)
    # Release the raw PyTorch-side tensors before returning to cap peak memory.
    del tensors
    jax.clear_caches()
    return flax_state_dict
3091

3192
def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
3293
device = jax.devices(device)[0]
@@ -48,25 +109,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
48109
for pt_key, tensor in loaded_state_dict.items():
49110
tensor = torch2jax(tensor)
50111
renamed_pt_key = rename_key(pt_key)
51-
renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table")
52-
renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out")
53-
renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1")
54-
renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2")
55-
renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1")
56-
renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2")
57-
renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj")
58-
59-
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
60-
renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1")
61-
renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2")
62-
renamed_pt_key = renamed_pt_key.replace(".q.", ".query.")
63-
renamed_pt_key = renamed_pt_key.replace(".k.", ".key.")
64-
renamed_pt_key = renamed_pt_key.replace(".v.", ".value.")
65-
renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.")
66-
renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj")
67-
renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out")
68-
renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table")
69-
renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm")
112+
renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key)
70113

71114
pt_tuple_key = tuple(renamed_pt_key.split("."))
72115

@@ -85,6 +128,8 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict,
85128

86129
if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
87130
return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
131+
elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
132+
return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
88133
else:
89134
return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
90135

src/maxdiffusion/pyconfig.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import yaml
2626
from . import max_logging
2727
from . import max_utils
28-
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
28+
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
2929

3030

3131
def string_to_bool(s: str) -> bool:
@@ -118,7 +118,7 @@ def wan_init(raw_keys):
118118
transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
119119
if transformer_pretrained_model_name_or_path == "":
120120
raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
121-
elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
121+
elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
122122
# Set correct parameters for CausVid in case of user error.
123123
raw_keys["guidance_scale"] = 1.0
124124
num_inference_steps = raw_keys["num_inference_steps"]

0 commit comments

Comments
 (0)