99from ..modeling_flax_pytorch_utils import (rename_key , rename_key_and_reshape_tensor , torch2jax , validate_flax_state_dict )
1010
# HuggingFace Hub repo ids for Wan 2.1 transformer variants whose checkpoints
# need the custom PyTorch->Flax key renaming performed below; the dispatcher
# in load_wan_transformer compares the requested path against these.
CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX"
1213
1314
1415def _tuple_str_to_int (in_tuple ):
@@ -28,6 +29,69 @@ def rename_for_nnx(key):
2829 return new_key
2930
3031
def rename_for_custom_trasformer(key):
  """Translate a PyTorch Wan-transformer checkpoint key to the Flax layout.

  Strips the ``model.diffusion_model.`` prefix and then applies a fixed
  sequence of substring substitutions. The sequence is order-sensitive:
  ``head.modulation`` must be rewritten before the generic ``.modulation``
  rule, and attention renames (``self_attn``/``cross_attn``) must precede
  the ``.q.``/``.k.``/``.v.``/``.o.`` projections they contain.

  Args:
    key: original checkpoint key, e.g. ``model.diffusion_model.blocks_0.self_attn.q.weight``.

  Returns:
    The renamed, dot-separated key string.
  """
  # (old, new) substitution pairs, applied strictly in this order.
  substitutions = (
      ("model.diffusion_model.", ""),
      ("head.modulation", "scale_shift_table"),
      ("head.head", "proj_out"),
      ("text_embedding_0", "condition_embedder.text_embedder.linear_1"),
      ("text_embedding_2", "condition_embedder.text_embedder.linear_2"),
      ("time_embedding_0", "condition_embedder.time_embedder.linear_1"),
      ("time_embedding_2", "condition_embedder.time_embedder.linear_2"),
      ("time_projection_1", "condition_embedder.time_proj"),
      ("blocks_", "blocks."),
      ("self_attn", "attn1"),
      ("cross_attn", "attn2"),
      (".q.", ".query."),
      (".k.", ".key."),
      (".v.", ".value."),
      (".o.", ".proj_attn."),
      ("ffn_0", "ffn.act_fn.proj"),
      ("ffn_2", "ffn.proj_out"),
      (".modulation", ".scale_shift_table"),
      ("norm3", "norm2.layer_norm"),
  )
  renamed = key
  for old, new in substitutions:
    renamed = renamed.replace(old, new)
  return renamed
56+
57+
def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
  """Load the FusionX Wan 2.1 transformer checkpoint into a Flax state dict.

  Downloads the single-file safetensors checkpoint from the Hub, converts
  each tensor to JAX, renames PyTorch keys to the Flax/NNX layout, validates
  the result against ``eval_shapes``, and returns the nested weight dict.

  Args:
    pretrained_model_name_or_path: HuggingFace Hub repo id of the checkpoint.
    eval_shapes: pytree of abstract shapes for the target Flax model, used to
      match and reshape incoming tensors.
    device: JAX device string the conversion runs under.
    hf_download: must be True; there is no local-path fallback for this
      checkpoint.

  Returns:
    Nested (unflattened) dict of converted weights, placed on host CPU.

  Raises:
    ValueError: if ``hf_download`` is False.
  """
  if not hf_download:
    # Bug fix: previously this fell through to an unassigned `ckpt_shard_path`
    # and crashed with a confusing NameError. Fail early and explicitly.
    raise ValueError("load_fusionx_transformer requires hf_download=True; no local checkpoint path is supported.")
  device = jax.devices(device)[0]
  with jax.default_device(device):
    ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors")
    tensors = {}
    with safe_open(ckpt_shard_path, framework="pt") as f:
      for k in f.keys():
        tensors[k] = torch2jax(f.get_tensor(k))

    flax_state_dict = {}
    cpu = jax.local_devices(backend="cpu")[0]
    flattened_dict = flatten_dict(eval_shapes)
    # Turn all block numbers into strings just for weight matching; they are
    # converted back to ints by _tuple_str_to_int below.
    random_flax_state_dict = {tuple(str(item) for item in key): value for key, value in flattened_dict.items()}
    for pt_key, tensor in tensors.items():
      renamed_pt_key = rename_key(pt_key)
      renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key)
      pt_tuple_key = tuple(renamed_pt_key.split("."))
      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
      flax_key = rename_for_nnx(flax_key)
      flax_key = _tuple_str_to_int(flax_key)
      # Park converted weights on host memory so conversion does not exhaust
      # accelerator HBM before the model is sharded.
      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
    validate_flax_state_dict(eval_shapes, flax_state_dict)
    flax_state_dict = unflatten_dict(flax_state_dict)
    del tensors
    jax.clear_caches()
    return flax_state_dict
93+
94+
3195def load_causvid_transformer (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
3296 device = jax .devices (device )[0 ]
3397 with jax .default_device (device ):
@@ -48,25 +112,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
48112 for pt_key , tensor in loaded_state_dict .items ():
49113 tensor = torch2jax (tensor )
50114 renamed_pt_key = rename_key (pt_key )
51- renamed_pt_key = renamed_pt_key .replace ("head.modulation" , "scale_shift_table" )
52- renamed_pt_key = renamed_pt_key .replace ("head.head" , "proj_out" )
53- renamed_pt_key = renamed_pt_key .replace ("text_embedding_0" , "condition_embedder.text_embedder.linear_1" )
54- renamed_pt_key = renamed_pt_key .replace ("text_embedding_2" , "condition_embedder.text_embedder.linear_2" )
55- renamed_pt_key = renamed_pt_key .replace ("time_embedding_0" , "condition_embedder.time_embedder.linear_1" )
56- renamed_pt_key = renamed_pt_key .replace ("time_embedding_2" , "condition_embedder.time_embedder.linear_2" )
57- renamed_pt_key = renamed_pt_key .replace ("time_projection_1" , "condition_embedder.time_proj" )
58-
59- renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
60- renamed_pt_key = renamed_pt_key .replace ("self_attn" , "attn1" )
61- renamed_pt_key = renamed_pt_key .replace ("cross_attn" , "attn2" )
62- renamed_pt_key = renamed_pt_key .replace (".q." , ".query." )
63- renamed_pt_key = renamed_pt_key .replace (".k." , ".key." )
64- renamed_pt_key = renamed_pt_key .replace (".v." , ".value." )
65- renamed_pt_key = renamed_pt_key .replace (".o." , ".proj_attn." )
66- renamed_pt_key = renamed_pt_key .replace ("ffn_0" , "ffn.act_fn.proj" )
67- renamed_pt_key = renamed_pt_key .replace ("ffn_2" , "ffn.proj_out" )
68- renamed_pt_key = renamed_pt_key .replace (".modulation" , ".scale_shift_table" )
69- renamed_pt_key = renamed_pt_key .replace ("norm3" , "norm2.layer_norm" )
115+ renamed_pt_key = rename_for_custom_trasformer (renamed_pt_key )
70116
71117 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
72118
@@ -85,6 +131,8 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict,
85131
86132 if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH :
87133 return load_causvid_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download )
134+ elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH :
135+ return load_fusionx_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download )
88136 else :
89137 return load_base_wan_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download )
90138
0 commit comments