@@ -177,6 +177,26 @@ def load_causvid_transformer(
   return flax_state_dict


+def load_wan_transformer(
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    num_layers: int = 40,
+    scan_layers: bool = True,
+    subfolder: str = "",
+):
+
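+  # Route to the loader that matches the checkpoint: CausVid and FusionX
+  # checkpoints need their own weight porting; anything else falls through
+  # to the base WAN loader.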
+  if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
+    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
+  elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
+    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
+  else:
+    return load_base_wan_transformer(
+        pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers, subfolder
+    )
+
+
 def load_base_wan_transformer(
     pretrained_model_name_or_path: str,
     eval_shapes: dict,
@@ -186,7 +206,6 @@ def load_base_wan_transformer(
     scan_layers: bool = True,
     subfolder: str = "",
 ):
-  print(f"\n=== DEBUG START: Loading Transformer from {pretrained_model_name_or_path} ===")
   device = jax.local_devices(backend=device)[0]
   filename = "diffusion_pytorch_model.safetensors.index.json"
   local_files = False
@@ -196,13 +215,14 @@ def load_base_wan_transformer(
       raise FileNotFoundError(f"File {index_file_path} not found for local directory.")
     local_files = True
   elif hf_download:
+    # Download the index file for sharded models.
     index_file_path = hf_hub_download(
         pretrained_model_name_or_path,
         subfolder=subfolder,
         filename=filename,
     )
-
   with jax.default_device(device):
+    # Read the index file to collect the shard filenames.
     with open(index_file_path, "r") as f:
       index_dict = json.load(f)
       model_files = set()
@@ -211,68 +231,37 @@ def load_base_wan_transformer(

     model_files = list(model_files)
     tensors = {}
-    print(f"=== DEBUG: Loading {len(model_files)} shard files... ===")
-
     for model_file in model_files:
       if local_files:
         ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
       else:
         ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
-
+      # Load each shard and convert its tensors from PyTorch to JAX.
       max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")

       if ckpt_shard_path is not None:
         with safe_open(ckpt_shard_path, framework="pt") as f:
           for k in f.keys():
             tensors[k] = torch2jax(f.get_tensor(k))
-
-    print(f"=== DEBUG: Total tensors loaded: {len(tensors)} ===")
-
     flax_state_dict = {}
     cpu = jax.local_devices(backend="cpu")[0]
     flattened_dict = flatten_dict(eval_shapes)
+    # Turn all block numbers into strings just for matching weights.
+    # Later they will be turned back into ints.
     random_flax_state_dict = {}
     for key in flattened_dict:
       string_tuple = tuple([str(item) for item in key])
       random_flax_state_dict[string_tuple] = flattened_dict[key]
     del flattened_dict
-
-    # --- 1. Initialize Buffer ---
     norm_added_q_buffer = {}
-    norm_added_q_debug_count = 0
-
-    print("=== DEBUG: Starting Tensor Loop ===")
     for pt_key, tensor in tensors.items():
       renamed_pt_key = rename_key(pt_key)
-
-      # --- 2. Buffer 'norm_added_q' Logic ---
       if "norm_added_q" in pt_key:
-        norm_added_q_debug_count += 1
-        # Print the first 3 to confirm we are hitting this block
-        if norm_added_q_debug_count <= 3:
-          print(f"DEBUG: Catching norm_added_q key: {pt_key}")
-
-        try:
-          parts = pt_key.split(".")
-          # Robust parsing: Find the part that is a digit
-          block_idx = -1
-          for part in parts:
-            if part.isdigit():
-              block_idx = int(part)
-              break
-
-          if block_idx == -1:
-            print(f"DEBUG ERROR: Could not find block index in {pt_key}")
-            continue
-
-          tensor = tensor.T
-          norm_added_q_buffer[block_idx] = tensor
-        except Exception as e:
-          print(f"DEBUG EXCEPTION parsing {pt_key}: {e}")
-
-        continue  # SKIP rest of loop
-
-      # --- 3. Image Embedder Logic ---
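+        # Buffer the transposed norm_added_q weight for this block; all
+        # blocks are stacked into one array after the loop.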
+        parts = pt_key.split(".")
+        block_idx = int(parts[1])
+        tensor = tensor.T
+        norm_added_q_buffer[block_idx] = tensor
+        continue
       if "image_embedder" in renamed_pt_key:
         if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
            "net.2" in renamed_pt_key or "net_2" in renamed_pt_key:
@@ -292,53 +281,22 @@ def load_base_wan_transformer(
       if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
         renamed_pt_key = renamed_pt_key.replace("weight", "scale")
         renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
-
-      # --- 4. Global Replacements ---
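+      # Map the remaining PyTorch module names onto their Flax equivalents.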
       renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
       renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
       renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
       renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
       renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
-
       if "norm2.layer_norm" not in renamed_pt_key:
-        renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
-
+        renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
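+      # Convert the dotted key to a tuple path, match it against the Flax
+      # shapes, and stage the tensor on host CPU.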
       pt_tuple_key = tuple(renamed_pt_key.split("."))
       flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
-
-    # --- 5. Final Stack & Insert ---
-    print(f"=== DEBUG: Loop Finished. norm_added_q count: {len(norm_added_q_buffer)} ===")
-
     if norm_added_q_buffer:
-      sorted_indices = sorted(norm_added_q_buffer.keys())
-      print(f"DEBUG: Stacking indices from {sorted_indices[0]} to {sorted_indices[-1]}")
-
-      sorted_tensors = [norm_added_q_buffer[i] for i in sorted_indices]
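+      # Stack the buffered per-block weights in block order along a new
+      # leading axis.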
+      sorted_tensors = [norm_added_q_buffer[i] for i in sorted(norm_added_q_buffer.keys())]
       stacked_tensor = jnp.stack(sorted_tensors, axis=0)
-      print(f"DEBUG: Final Stacked Shape: {stacked_tensor.shape}")
-
-      # INSERT AND VERIFY
       final_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
       flax_state_dict[final_key] = jax.device_put(stacked_tensor, device=cpu)
-
-      if final_key in flax_state_dict:
-        print(f"DEBUG: SUCCESS - Key {final_key} exists in dict.")
-      else:
-        print(f"DEBUG: CRITICAL FAILURE - Key insertion failed.")
-    else:
-      print("DEBUG: FAILURE - Buffer is EMPTY. No norm_added_q found!")
-
-    print("=== DEBUG: Starting Validation ===")
-    # Print what the validator EXPECTS vs what we HAVE
-    expected_keys = set(flatten_dict(eval_shapes).keys())
-    actual_keys = set(flax_state_dict.keys())
-
-    missing = expected_keys - actual_keys
-    print(f"DEBUG: Missing Keys Count: {len(missing)}")
-    if len(missing) > 0 and len(missing) < 10:
-      print(f"DEBUG: Missing Keys: {missing}")
-
+
     validate_flax_state_dict(eval_shapes, flax_state_dict)
     flax_state_dict = unflatten_dict(flax_state_dict)
     del tensors
@@ -404,4 +362,4 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device:
   else:
     raise FileNotFoundError(f"Path {ckpt_path} was not found")

-  return flax_state_dict
+  return flax_state_dict