2424from safetensors import safe_open
2525from flax .traverse_util import unflatten_dict , flatten_dict
2626from ..modeling_flax_pytorch_utils import (rename_key , rename_key_and_reshape_tensor , torch2jax , validate_flax_state_dict )
27- from ...common_types import WAN_MODEL
2827
2928CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
3029WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX"
@@ -73,8 +72,35 @@ def rename_for_custom_trasformer(key):
7372 return renamed_pt_key
7473
7574
75+ def get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers ):
76+ if scan_layers :
77+ if "blocks" in pt_tuple_key :
78+ new_key = ("blocks" ,) + pt_tuple_key [2 :]
79+ block_index = int (pt_tuple_key [1 ])
80+ pt_tuple_key = new_key
81+
82+ flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict , scan_layers )
83+
84+ flax_key = rename_for_nnx (flax_key )
85+ flax_key = _tuple_str_to_int (flax_key )
86+
87+ if scan_layers :
88+ if "blocks" in flax_key :
89+ if flax_key in flax_state_dict :
90+ new_tensor = flax_state_dict [flax_key ]
91+ else :
92+ new_tensor = jnp .zeros ((40 ,) + flax_tensor .shape )
93+ flax_tensor = new_tensor .at [block_index ].set (flax_tensor )
94+ return flax_key , flax_tensor
95+
96+
7697def load_fusionx_transformer (
77- pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True , num_layers : int = 40
98+ pretrained_model_name_or_path : str ,
99+ eval_shapes : dict ,
100+ device : str ,
101+ hf_download : bool = True ,
102+ num_layers : int = 40 ,
103+ scan_layers : bool = True ,
78104):
79105 device = jax .local_devices (backend = device )[0 ]
80106 with jax .default_device (device ):
@@ -101,23 +127,9 @@ def load_fusionx_transformer(
101127
102128 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
103129
104- if "blocks" in pt_tuple_key :
105- new_key = ("blocks" ,) + pt_tuple_key [2 :]
106- block_index = int (pt_tuple_key [1 ])
107- pt_tuple_key = new_key
108- flax_key , flax_tensor = rename_key_and_reshape_tensor (
109- pt_tuple_key , tensor , random_flax_state_dict , model_type = WAN_MODEL
110- )
111- flax_key = rename_for_nnx (flax_key )
112- flax_key = _tuple_str_to_int (flax_key )
113-
114- if "blocks" in flax_key :
115- if flax_key in flax_state_dict :
116- new_tensor = flax_state_dict [flax_key ]
117- else :
118- new_tensor = jnp .zeros ((num_layers ,) + flax_tensor .shape )
119- flax_tensor = new_tensor .at [block_index ].set (flax_tensor )
130+ flax_key , flax_tensor = get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers )
120131 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
132+
121133 validate_flax_state_dict (eval_shapes , flax_state_dict )
122134 flax_state_dict = unflatten_dict (flax_state_dict )
123135 del tensors
@@ -126,7 +138,12 @@ def load_fusionx_transformer(
126138
127139
128140def load_causvid_transformer (
129- pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True , num_layers : int = 40
141+ pretrained_model_name_or_path : str ,
142+ eval_shapes : dict ,
143+ device : str ,
144+ hf_download : bool = True ,
145+ num_layers : int = 40 ,
146+ scan_layers : bool = True ,
130147):
131148 device = jax .local_devices (backend = device )[0 ]
132149 with jax .default_device (device ):
@@ -150,24 +167,9 @@ def load_causvid_transformer(
150167 renamed_pt_key = rename_for_custom_trasformer (renamed_pt_key )
151168
152169 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
153-
154- if "blocks" in pt_tuple_key :
155- new_key = ("blocks" ,) + pt_tuple_key [2 :]
156- block_index = int (pt_tuple_key [1 ])
157- pt_tuple_key = new_key
158- flax_key , flax_tensor = rename_key_and_reshape_tensor (
159- pt_tuple_key , tensor , random_flax_state_dict , model_type = WAN_MODEL
160- )
161- flax_key = rename_for_nnx (flax_key )
162- flax_key = _tuple_str_to_int (flax_key )
163-
164- if "blocks" in flax_key :
165- if flax_key in flax_state_dict :
166- new_tensor = flax_state_dict [flax_key ]
167- else :
168- new_tensor = jnp .zeros ((num_layers ,) + flax_tensor .shape )
169- flax_tensor = new_tensor .at [block_index ].set (flax_tensor )
170+ flax_key , flax_tensor = get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers )
170171 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
172+
171173 validate_flax_state_dict (eval_shapes , flax_state_dict )
172174 flax_state_dict = unflatten_dict (flax_state_dict )
173175 del tensors
@@ -176,19 +178,31 @@ def load_causvid_transformer(
176178
177179
178180def load_wan_transformer (
179- pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True , num_layers : int = 40
181+ pretrained_model_name_or_path : str ,
182+ eval_shapes : dict ,
183+ device : str ,
184+ hf_download : bool = True ,
185+ num_layers : int = 40 ,
186+ scan_layers : bool = True ,
180187):
181188
182189 if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH :
183- return load_causvid_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download , num_layers )
190+ return load_causvid_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download , num_layers , scan_layers )
184191 elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH :
185- return load_fusionx_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download , num_layers )
192+ return load_fusionx_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download , num_layers , scan_layers )
186193 else :
187- return load_base_wan_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download , num_layers )
194+ return load_base_wan_transformer (
195+ pretrained_model_name_or_path , eval_shapes , device , hf_download , num_layers , scan_layers
196+ )
188197
189198
190199def load_base_wan_transformer (
191- pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True , num_layers : int = 40
200+ pretrained_model_name_or_path : str ,
201+ eval_shapes : dict ,
202+ device : str ,
203+ hf_download : bool = True ,
204+ num_layers : int = 40 ,
205+ scan_layers : bool = True ,
192206):
193207 device = jax .local_devices (backend = device )[0 ]
194208 subfolder = "transformer"
@@ -247,24 +261,9 @@ def load_base_wan_transformer(
247261 renamed_pt_key = renamed_pt_key .replace ("ffn.net_0" , "ffn.act_fn" )
248262 renamed_pt_key = renamed_pt_key .replace ("norm2" , "norm2.layer_norm" )
249263 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
250-
251- if "blocks" in pt_tuple_key :
252- new_key = ("blocks" ,) + pt_tuple_key [2 :]
253- block_index = int (pt_tuple_key [1 ])
254- pt_tuple_key = new_key
255- flax_key , flax_tensor = rename_key_and_reshape_tensor (
256- pt_tuple_key , tensor , random_flax_state_dict , model_type = WAN_MODEL
257- )
258- flax_key = rename_for_nnx (flax_key )
259- flax_key = _tuple_str_to_int (flax_key )
260-
261- if "blocks" in flax_key :
262- if flax_key in flax_state_dict :
263- new_tensor = flax_state_dict [flax_key ]
264- else :
265- new_tensor = jnp .zeros ((num_layers ,) + flax_tensor .shape )
266- flax_tensor = new_tensor .at [block_index ].set (flax_tensor )
264+ flax_key , flax_tensor = get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers )
267265 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
266+
268267 validate_flax_state_dict (eval_shapes , flax_state_dict )
269268 flax_state_dict = unflatten_dict (flax_state_dict )
270269 del tensors
0 commit comments