Skip to content

Commit 6c52603

Browse files
Fusion x wan (#198)
* add fusion x support. --------- Co-authored-by: Juan Acevedo <jfacevedo@google.com>
1 parent 31dcb6c commit 6c52603

2 files changed

Lines changed: 72 additions & 21 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
1010

1111
CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
12+
WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX"
1213

1314

1415
def _tuple_str_to_int(in_tuple):
@@ -28,6 +29,69 @@ def rename_for_nnx(key):
2829
return new_key
2930

3031

32+
def rename_for_custom_trasformer(key):
  """Map a PyTorch WAN/FusionX checkpoint key to the flax transformer naming scheme.

  Strips the ``model.diffusion_model.`` prefix and applies an ordered sequence
  of substring substitutions. Order matters: the specific ``head.modulation``
  rule must fire before the generic ``.modulation`` rule.
  """
  # (old, new) pairs, applied in order with str.replace.
  substitutions = (
      ("model.diffusion_model.", ""),
      ("head.modulation", "scale_shift_table"),
      ("head.head", "proj_out"),
      ("text_embedding_0", "condition_embedder.text_embedder.linear_1"),
      ("text_embedding_2", "condition_embedder.text_embedder.linear_2"),
      ("time_embedding_0", "condition_embedder.time_embedder.linear_1"),
      ("time_embedding_2", "condition_embedder.time_embedder.linear_2"),
      ("time_projection_1", "condition_embedder.time_proj"),
      ("blocks_", "blocks."),
      ("self_attn", "attn1"),
      ("cross_attn", "attn2"),
      (".q.", ".query."),
      (".k.", ".key."),
      (".v.", ".value."),
      (".o.", ".proj_attn."),
      ("ffn_0", "ffn.act_fn.proj"),
      ("ffn_2", "ffn.proj_out"),
      (".modulation", ".scale_shift_table"),
      ("norm3", "norm2.layer_norm"),
  )
  renamed = key
  for old, new in substitutions:
    renamed = renamed.replace(old, new)
  return renamed
56+
57+
58+
def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
  """Load FusionX WAN transformer weights and convert them to a flax state dict.

  Downloads the single-file safetensors checkpoint from the Hugging Face hub,
  renames every PyTorch key to the flax naming scheme, reshapes tensors to
  match ``eval_shapes``, and returns the nested state dict staged on CPU.

  Args:
    pretrained_model_name_or_path: HF repo id holding the FusionX checkpoint.
    eval_shapes: flattened-able pytree of expected flax parameter shapes,
      used for key matching, reshaping, and final validation.
    device: jax platform name (e.g. "cpu", "tpu") used as the default device
      while loading.
    hf_download: must be True; no local-path fallback is implemented.

  Returns:
    Nested (unflattened) flax state dict with tensors device_put on CPU.

  Raises:
    ValueError: if ``hf_download`` is False.
  """
  if not hf_download:
    # Previously ckpt_shard_path was simply left unassigned in this case,
    # which surfaced later as a confusing NameError. Fail fast instead.
    raise ValueError("load_fusionx_transformer requires hf_download=True; no local checkpoint path is supported.")
  device = jax.devices(device)[0]
  with jax.default_device(device):
    ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors")
    tensors = {}
    with safe_open(ckpt_shard_path, framework="pt") as f:
      for k in f.keys():
        tensors[k] = torch2jax(f.get_tensor(k))

    flax_state_dict = {}
    cpu = jax.local_devices(backend="cpu")[0]
    flattened_dict = flatten_dict(eval_shapes)
    # Turn all block numbers to strings just for matching weights.
    # Later they will be turned back to ints.
    random_flax_state_dict = {}
    for key in flattened_dict:
      string_tuple = tuple([str(item) for item in key])
      random_flax_state_dict[string_tuple] = flattened_dict[key]
    for pt_key, tensor in tensors.items():
      renamed_pt_key = rename_key(pt_key)
      renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key)
      pt_tuple_key = tuple(renamed_pt_key.split("."))
      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
      flax_key = rename_for_nnx(flax_key)
      flax_key = _tuple_str_to_int(flax_key)
      # Stage every tensor on CPU so the full checkpoint is held in host memory.
      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
    validate_flax_state_dict(eval_shapes, flax_state_dict)
    flax_state_dict = unflatten_dict(flax_state_dict)
    # Drop the torch-side copies before returning to reduce peak memory.
    del tensors
    jax.clear_caches()
    return flax_state_dict
93+
94+
3195
def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
3296
device = jax.devices(device)[0]
3397
with jax.default_device(device):
@@ -48,25 +112,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
48112
for pt_key, tensor in loaded_state_dict.items():
49113
tensor = torch2jax(tensor)
50114
renamed_pt_key = rename_key(pt_key)
51-
renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table")
52-
renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out")
53-
renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1")
54-
renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2")
55-
renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1")
56-
renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2")
57-
renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj")
58-
59-
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
60-
renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1")
61-
renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2")
62-
renamed_pt_key = renamed_pt_key.replace(".q.", ".query.")
63-
renamed_pt_key = renamed_pt_key.replace(".k.", ".key.")
64-
renamed_pt_key = renamed_pt_key.replace(".v.", ".value.")
65-
renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.")
66-
renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj")
67-
renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out")
68-
renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table")
69-
renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm")
115+
renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key)
70116

71117
pt_tuple_key = tuple(renamed_pt_key.split("."))
72118

@@ -85,6 +131,8 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict,
85131

86132
if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
87133
return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
134+
elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
135+
return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
88136
else:
89137
return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
90138

src/maxdiffusion/pyconfig.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import yaml
2626
from . import max_logging
2727
from . import max_utils
28-
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
28+
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
2929

3030

3131
def string_to_bool(s: str) -> bool:
@@ -118,7 +118,10 @@ def wan_init(raw_keys):
118118
transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
119119
if transformer_pretrained_model_name_or_path == "":
120120
raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
121-
elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
121+
elif (
122+
transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
123+
or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH
124+
):
122125
# Set correct parameters for CausVid/FusionX in case of user error.
123126
raw_keys["guidance_scale"] = 1.0
124127
num_inference_steps = raw_keys["num_inference_steps"]

0 commit comments

Comments
 (0)