@@ -137,10 +137,11 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict,
137137 else :
138138 return load_base_wan_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download )
139139
140+
140141def load_base_wan_transformer (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
141142 device = jax .devices (device )[0 ]
142- subfolder = "transformer"
143- filename = "diffusion_pytorch_model.safetensors.index.json"
143+ subfolder = "transformer"
144+ filename = "diffusion_pytorch_model.safetensors.index.json"
144145 local_files = False
145146 if os .path .isdir (pretrained_model_name_or_path ):
146147 index_file_path = os .path .join (pretrained_model_name_or_path , subfolder , filename )
@@ -150,72 +151,72 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
150151 elif hf_download :
151152 # download the index file for sharded models.
152153 index_file_path = hf_hub_download (
153- pretrained_model_name_or_path , subfolder = subfolder , filename = filename ,
154+ pretrained_model_name_or_path ,
155+ subfolder = subfolder ,
156+ filename = filename ,
154157 )
155- with jax .default_device (device ):
156- # open the index file.
157- with open (index_file_path , "r" ) as f :
158- index_dict = json .load (f )
159- model_files = set ()
160- for key in index_dict ["weight_map" ].keys ():
161- model_files .add (index_dict ["weight_map" ][key ])
162-
163- model_files = list (model_files )
164- tensors = {}
165- for model_file in model_files :
166- if local_files :
167- ckpt_shard_path = os .path .join (pretrained_model_name_or_path , subfolder , model_file )
168- else :
169- ckpt_shard_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = model_file )
170- # now get all the filenames for the model that need downloading
171- max_logging .log (f"Load and port Wan 2.1 transformer on { device } " )
172-
173- if ckpt_shard_path is not None :
174- with safe_open (ckpt_shard_path , framework = "pt" ) as f :
175- for k in f .keys ():
176- tensors [k ] = torch2jax (f .get_tensor (k ))
177- flax_state_dict = {}
178- cpu = jax .local_devices (backend = "cpu" )[0 ]
179- flattened_dict = flatten_dict (eval_shapes )
180- # turn all block numbers to strings just for matching weights.
181- # Later they will be turned back to ints.
182- random_flax_state_dict = {}
183- for key in flattened_dict :
184- string_tuple = tuple ([str (item ) for item in key ])
185- random_flax_state_dict [string_tuple ] = flattened_dict [key ]
186- del flattened_dict
187- for pt_key , tensor in tensors .items ():
188- renamed_pt_key = rename_key (pt_key )
189- renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
190- renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )
191- renamed_pt_key = renamed_pt_key .replace ("ffn.net_2" , "ffn.proj_out" )
192- renamed_pt_key = renamed_pt_key .replace ("ffn.net_0" , "ffn.act_fn" )
193- renamed_pt_key = renamed_pt_key .replace ("norm2" , "norm2.layer_norm" )
194- pt_tuple_key = tuple (renamed_pt_key .split ("." ))
195-
196- flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict )
197- flax_key = rename_for_nnx (flax_key )
198- flax_key = _tuple_str_to_int (flax_key )
199- flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
200- validate_flax_state_dict (eval_shapes , flax_state_dict )
201- flax_state_dict = unflatten_dict (flax_state_dict )
202- del tensors
203- jax .clear_caches ()
204- return flax_state_dict
158+ with jax .default_device (device ):
159+ # open the index file.
160+ with open (index_file_path , "r" ) as f :
161+ index_dict = json .load (f )
162+ model_files = set ()
163+ for key in index_dict ["weight_map" ].keys ():
164+ model_files .add (index_dict ["weight_map" ][key ])
165+
166+ model_files = list (model_files )
167+ tensors = {}
168+ for model_file in model_files :
169+ if local_files :
170+ ckpt_shard_path = os .path .join (pretrained_model_name_or_path , subfolder , model_file )
171+ else :
172+ ckpt_shard_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = model_file )
173+ # now get all the filenames for the model that need downloading
174+ max_logging .log (f"Load and port Wan 2.1 transformer on { device } " )
175+
176+ if ckpt_shard_path is not None :
177+ with safe_open (ckpt_shard_path , framework = "pt" ) as f :
178+ for k in f .keys ():
179+ tensors [k ] = torch2jax (f .get_tensor (k ))
180+ flax_state_dict = {}
181+ cpu = jax .local_devices (backend = "cpu" )[0 ]
182+ flattened_dict = flatten_dict (eval_shapes )
183+ # turn all block numbers to strings just for matching weights.
184+ # Later they will be turned back to ints.
185+ random_flax_state_dict = {}
186+ for key in flattened_dict :
187+ string_tuple = tuple ([str (item ) for item in key ])
188+ random_flax_state_dict [string_tuple ] = flattened_dict [key ]
189+ del flattened_dict
190+ for pt_key , tensor in tensors .items ():
191+ renamed_pt_key = rename_key (pt_key )
192+ renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
193+ renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )
194+ renamed_pt_key = renamed_pt_key .replace ("ffn.net_2" , "ffn.proj_out" )
195+ renamed_pt_key = renamed_pt_key .replace ("ffn.net_0" , "ffn.act_fn" )
196+ renamed_pt_key = renamed_pt_key .replace ("norm2" , "norm2.layer_norm" )
197+ pt_tuple_key = tuple (renamed_pt_key .split ("." ))
198+
199+ flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict )
200+ flax_key = rename_for_nnx (flax_key )
201+ flax_key = _tuple_str_to_int (flax_key )
202+ flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
203+ validate_flax_state_dict (eval_shapes , flax_state_dict )
204+ flax_state_dict = unflatten_dict (flax_state_dict )
205+ del tensors
206+ jax .clear_caches ()
207+ return flax_state_dict
205208
206209
207210def load_wan_vae (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
208211 device = jax .devices (device )[0 ]
209- subfolder = "vae"
210- filename = "diffusion_pytorch_model.safetensors"
212+ subfolder = "vae"
213+ filename = "diffusion_pytorch_model.safetensors"
211214 if os .path .isdir (pretrained_model_name_or_path ):
212215 ckpt_path = os .path .join (pretrained_model_name_or_path , subfolder , filename )
213216 if not os .path .isfile (ckpt_path ):
214217 raise FileNotFoundError (f"File { ckpt_path } not found for local directory." )
215218 elif hf_download :
216- ckpt_path = hf_hub_download (
217- pretrained_model_name_or_path , subfolder = subfolder , filename = filename
218- )
219+ ckpt_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = filename )
219220 max_logging .log (f"Load and port Wan 2.1 VAE on { device } " )
220221 with jax .default_device (device ):
221222 if ckpt_path is not None :
0 commit comments