99from ..modeling_flax_pytorch_utils import (rename_key , rename_key_and_reshape_tensor , torch2jax , validate_flax_state_dict )
1010
# Hugging Face Hub repo ids for checkpoints that need custom key renaming
# (see rename_for_custom_trasformer); load_wan_transformer dispatches on these.
CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX"
1213
1314
1415def _tuple_str_to_int (in_tuple ):
@@ -27,6 +28,66 @@ def rename_for_nnx(key):
2728 new_key = key [:- 1 ] + ("scale" ,)
2829 return new_key
2930
def rename_for_custom_trasformer(key):
    """Translate a PyTorch Wan-transformer checkpoint key into the Flax naming scheme.

    Strips the ``model.diffusion_model.`` prefix, then applies an ordered list of
    substring substitutions covering the output head, the condition embedders, and
    the per-block attention / feed-forward / norm submodules.

    Args:
        key: a dotted PyTorch state-dict key.

    Returns:
        The renamed dotted key expected by the Flax model layout.
    """
    # Order matters: e.g. "head.modulation" must be rewritten before the generic
    # ".modulation" rule, and "blocks_" before the ".q."/".k."/... attention rules.
    substitutions = (
        ("model.diffusion_model.", ""),
        # Output head.
        ("head.modulation", "scale_shift_table"),
        ("head.head", "proj_out"),
        # Text / time condition embedders.
        ("text_embedding_0", "condition_embedder.text_embedder.linear_1"),
        ("text_embedding_2", "condition_embedder.text_embedder.linear_2"),
        ("time_embedding_0", "condition_embedder.time_embedder.linear_1"),
        ("time_embedding_2", "condition_embedder.time_embedder.linear_2"),
        ("time_projection_1", "condition_embedder.time_proj"),
        # Transformer blocks.
        ("blocks_", "blocks."),
        ("self_attn", "attn1"),
        ("cross_attn", "attn2"),
        (".q.", ".query."),
        (".k.", ".key."),
        (".v.", ".value."),
        (".o.", ".proj_attn."),
        ("ffn_0", "ffn.act_fn.proj"),
        ("ffn_2", "ffn.proj_out"),
        (".modulation", ".scale_shift_table"),
        ("norm3", "norm2.layer_norm"),
    )
    renamed = key
    for old, new in substitutions:
        renamed = renamed.replace(old, new)
    return renamed
55+
def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
    """Load the FusionX Wan 2.1 transformer checkpoint and convert it to a Flax state dict.

    Downloads the single-file safetensors checkpoint from the Hub, renames every
    PyTorch key to the Flax layout, reshapes tensors to match ``eval_shapes``,
    and returns the nested (unflattened) Flax state dict with all weights placed
    on host CPU.

    Args:
        pretrained_model_name_or_path: Hub repo id to download from.
        eval_shapes: flattenable pytree of expected Flax parameter shapes.
        device: JAX platform name used as the default device during conversion.
        hf_download: must be True; only Hub download is implemented.

    Returns:
        Nested dict of converted Flax parameters.

    Raises:
        ValueError: if ``hf_download`` is False.
    """
    if not hf_download:
        # Previously ckpt_shard_path was only assigned inside `if hf_download:`,
        # so hf_download=False crashed with NameError. Fail loudly and clearly.
        raise ValueError("load_fusionx_transformer only supports hf_download=True; no local-path loading is implemented.")
    device = jax.devices(device)[0]
    with jax.default_device(device):
        ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors")
        tensors = {}
        with safe_open(ckpt_shard_path, framework="pt") as f:
            for k in f.keys():
                tensors[k] = torch2jax(f.get_tensor(k))

        flax_state_dict = {}
        cpu = jax.local_devices(backend="cpu")[0]
        flattened_dict = flatten_dict(eval_shapes)
        # Turn all block numbers into strings just for matching against the
        # dotted PyTorch keys; _tuple_str_to_int converts them back below.
        random_flax_state_dict = {
            tuple(str(item) for item in key): value for key, value in flattened_dict.items()
        }
        for pt_key, tensor in tensors.items():
            renamed_pt_key = rename_key(pt_key)
            renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key)
            pt_tuple_key = tuple(renamed_pt_key.split("."))
            flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
            flax_key = rename_for_nnx(flax_key)
            flax_key = _tuple_str_to_int(flax_key)
            # Keep converted weights on host CPU; sharding happens later.
            flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
        validate_flax_state_dict(eval_shapes, flax_state_dict)
        flax_state_dict = unflatten_dict(flax_state_dict)
        del tensors
        jax.clear_caches()
        return flax_state_dict
3091
3192def load_causvid_transformer (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
3293 device = jax .devices (device )[0 ]
@@ -48,25 +109,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
48109 for pt_key , tensor in loaded_state_dict .items ():
49110 tensor = torch2jax (tensor )
50111 renamed_pt_key = rename_key (pt_key )
51- renamed_pt_key = renamed_pt_key .replace ("head.modulation" , "scale_shift_table" )
52- renamed_pt_key = renamed_pt_key .replace ("head.head" , "proj_out" )
53- renamed_pt_key = renamed_pt_key .replace ("text_embedding_0" , "condition_embedder.text_embedder.linear_1" )
54- renamed_pt_key = renamed_pt_key .replace ("text_embedding_2" , "condition_embedder.text_embedder.linear_2" )
55- renamed_pt_key = renamed_pt_key .replace ("time_embedding_0" , "condition_embedder.time_embedder.linear_1" )
56- renamed_pt_key = renamed_pt_key .replace ("time_embedding_2" , "condition_embedder.time_embedder.linear_2" )
57- renamed_pt_key = renamed_pt_key .replace ("time_projection_1" , "condition_embedder.time_proj" )
58-
59- renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
60- renamed_pt_key = renamed_pt_key .replace ("self_attn" , "attn1" )
61- renamed_pt_key = renamed_pt_key .replace ("cross_attn" , "attn2" )
62- renamed_pt_key = renamed_pt_key .replace (".q." , ".query." )
63- renamed_pt_key = renamed_pt_key .replace (".k." , ".key." )
64- renamed_pt_key = renamed_pt_key .replace (".v." , ".value." )
65- renamed_pt_key = renamed_pt_key .replace (".o." , ".proj_attn." )
66- renamed_pt_key = renamed_pt_key .replace ("ffn_0" , "ffn.act_fn.proj" )
67- renamed_pt_key = renamed_pt_key .replace ("ffn_2" , "ffn.proj_out" )
68- renamed_pt_key = renamed_pt_key .replace (".modulation" , ".scale_shift_table" )
69- renamed_pt_key = renamed_pt_key .replace ("norm3" , "norm2.layer_norm" )
112+ renamed_pt_key = rename_for_custom_trasformer (renamed_pt_key )
70113
71114 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
72115
@@ -85,6 +128,8 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict,
85128
86129 if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH :
87130 return load_causvid_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download )
131+ elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH :
132+ return load_fusionx_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download )
88133 else :
89134 return load_base_wan_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download )
90135
0 commit comments