@@ -57,8 +57,10 @@ def rename_for_custom_trasformer(key):
     return renamed_pt_key
 
 
-def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
-    device = jax.devices(device)[0]
+def load_fusionx_transformer(
+    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+):
+    device = jax.local_devices(backend=device)[0]
     with jax.default_device(device):
         if hf_download:
             ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors")
@@ -97,7 +99,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
                 if flax_key in flax_state_dict:
                     new_tensor = flax_state_dict[flax_key]
                 else:
-                    new_tensor = jnp.zeros((40,) + flax_tensor.shape)
+                    new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
                 flax_tensor = new_tensor.at[block_index].set(flax_tensor)
             flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
         validate_flax_state_dict(eval_shapes, flax_state_dict)
@@ -107,8 +109,10 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
     return flax_state_dict
 
 
-def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
-    device = jax.devices(device)[0]
+def load_causvid_transformer(
+    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+):
+    device = jax.local_devices(backend=device)[0]
     with jax.default_device(device):
         if hf_download:
             ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt")
@@ -145,7 +149,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
                 if flax_key in flax_state_dict:
                     new_tensor = flax_state_dict[flax_key]
                 else:
-                    new_tensor = jnp.zeros((40,) + flax_tensor.shape)
+                    new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
                 flax_tensor = new_tensor.at[block_index].set(flax_tensor)
             flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
         validate_flax_state_dict(eval_shapes, flax_state_dict)
@@ -155,18 +159,22 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
     return flax_state_dict
 
 
-def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
+def load_wan_transformer(
+    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+):
 
     if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
-        return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+        return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
     elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
-        return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+        return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
     else:
-        return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+        return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
 
 
-def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
-    device = jax.devices(device)[0]
+def load_base_wan_transformer(
+    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+):
+    device = jax.local_devices(backend=device)[0]
     subfolder = "transformer"
     filename = "diffusion_pytorch_model.safetensors.index.json"
     local_files = False
@@ -237,7 +245,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
                 if flax_key in flax_state_dict:
                     new_tensor = flax_state_dict[flax_key]
                 else:
-                    new_tensor = jnp.zeros((40,) + flax_tensor.shape)
+                    new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
                 flax_tensor = new_tensor.at[block_index].set(flax_tensor)
             flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
         validate_flax_state_dict(eval_shapes, flax_state_dict)
0 commit comments