@@ -215,14 +215,12 @@ def load_base_wan_transformer(
       raise FileNotFoundError(f"File {index_file_path} not found for local directory.")
     local_files = True
   elif hf_download:
-    # download the index file for sharded models.
     index_file_path = hf_hub_download(
         pretrained_model_name_or_path,
         subfolder=subfolder,
         filename=filename,
     )
   with jax.default_device(device):
-    # open the index file.
     with open(index_file_path, "r") as f:
       index_dict = json.load(f)
     model_files = set()
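Context for this hunk: the index file it downloads maps every tensor name in a sharded safetensors checkpoint to the shard that stores it, and the loader only needs the set of unique shard filenames. A minimal sketch of that lookup, assuming a hypothetical two-shard weight_map (the real one comes from the downloaded index JSON):

    import json

    # Hypothetical index contents; only the key -> shard mapping matters.
    index_json = '''{"weight_map": {
        "blocks.0.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
        "blocks.39.ffn.net_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors"}}'''
    index_dict = json.loads(index_json)

    # Deduplicate shard names, mirroring the loader's model_files set.
    model_files = set(index_dict["weight_map"].values())
    print(sorted(model_files))  # the two unique shard files to fetch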
@@ -236,37 +234,42 @@ def load_base_wan_transformer(
         ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
       else:
         ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
-      # now get all the filenames for the model that need downloading
+
       max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")
 
       if ckpt_shard_path is not None:
         with safe_open(ckpt_shard_path, framework="pt") as f:
           for k in f.keys():
             tensors[k] = torch2jax(f.get_tensor(k))
+
     flax_state_dict = {}
     cpu = jax.local_devices(backend="cpu")[0]
-    flattened_eval_shapes = flatten_dict(eval_shapes)
-    # turn all block numbers to strings just for matching weights.
-    # Later they will be turned back to ints.
+    flattened_dict = flatten_dict(eval_shapes)
     random_flax_state_dict = {}
-    for key in flattened_eval_shapes:
+    for key in flattened_dict:
       string_tuple = tuple([str(item) for item in key])
-      random_flax_state_dict[string_tuple] = flattened_eval_shapes[key]
-      # del flattened_dict
+      random_flax_state_dict[string_tuple] = flattened_dict[key]
+    del flattened_dict
+
+    # 1. Initialize buffer for norm_added_q
     norm_added_q_buffer = {}
+
     for pt_key, tensor in tensors.items():
+      # 2. Robustly intercept norm_added_q keys
       if "norm_added_q" in pt_key and "weight" in pt_key:
         parts = pt_key.split(".")
         try:
+          # Find the block index regardless of prefix (blocks.0 vs model.blocks.0)
           if "blocks" in parts:
             block_idx_loc = parts.index("blocks") + 1
             block_idx = int(parts[block_idx_loc])
             tensor = tensor.T
             norm_added_q_buffer[block_idx] = tensor
         except Exception:
-          pass
+          pass  # Skip if unparseable
         continue
-
+
+      # Standard processing
       renamed_pt_key = rename_key(pt_key)
       if "image_embedder" in renamed_pt_key:
         if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
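The interception added in this hunk routes every norm_added_q weight into a per-block buffer instead of the standard rename path. A minimal sketch of that control flow, assuming made-up key names and a made-up dimension of 8 (the real tensors are read via safe_open):

    import jax.numpy as jnp

    # Hypothetical checkpoint entries; only the key structure matters here.
    tensors = {
        "blocks.0.attn2.norm_added_q.weight": jnp.ones((8,)),
        "model.blocks.1.attn2.norm_added_q.weight": jnp.ones((8,)),
    }

    norm_added_q_buffer = {}
    for pt_key, tensor in tensors.items():
        parts = pt_key.split(".")
        # The block index always follows the "blocks" token, so the same
        # lookup works for "blocks.0..." and "model.blocks.0..." keys.
        block_idx = int(parts[parts.index("blocks") + 1])
        norm_added_q_buffer[block_idx] = tensor.T

    print(sorted(norm_added_q_buffer))  # [0, 1]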
@@ -287,39 +290,36 @@ def load_base_wan_transformer(
       if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
         renamed_pt_key = renamed_pt_key.replace("weight", "scale")
         renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
+
       renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
       renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
       renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
       renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
       renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
       if "norm2.layer_norm" not in renamed_pt_key:
         renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
+
       pt_tuple_key = tuple(renamed_pt_key.split("."))
       flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+
+    # 3. Stack and insert (the correct key name for RMSNorm is 'scale')
     if norm_added_q_buffer:
       sorted_keys = sorted(norm_added_q_buffer.keys())
       sorted_tensors = [norm_added_q_buffer[i] for i in sorted_keys]
       stacked_tensor = jnp.stack(sorted_tensors, axis=0)
-      final_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
+
+      # 'scale' is the correct parameter name for Flax/NNX RMSNorm
+      final_key = ('blocks', 'attn2', 'norm_added_q', 'scale')
+
       flax_state_dict[final_key] = jax.device_put(stacked_tensor, device=cpu)
-      print(f"DEBUG: Manually injected {final_key} into flax_state_dict")
-      if final_key not in flattened_eval_shapes:
-        print(f"DEBUG: Key {final_key} missing in eval_shapes. Patching it now.")
-        shape_struct = jax.ShapeDtypeStruct(
-            shape=stacked_tensor.shape,
-            dtype=stacked_tensor.dtype
-        )
-        flattened_eval_shapes[final_key] = shape_struct
-        eval_shapes = unflatten_dict(flattened_eval_shapes)
 
     validate_flax_state_dict(eval_shapes, flax_state_dict)
     flax_state_dict = unflatten_dict(flax_state_dict)
     del tensors
     jax.clear_caches()
     return flax_state_dict
 
-
 def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
   device = jax.devices(device)[0]
   subfolder = "vae"
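The stack-and-insert step at the end of the diff reflects a scanned-layer layout: with scan_layers=True, per-block parameters live under one index-free key with a leading block axis, so the buffered scales are stacked before injection. A sketch of the resulting shape and key, assuming a hypothetical 3-block model and a made-up feature dimension of 8:

    import jax
    import jax.numpy as jnp

    # Hypothetical buffered per-block RMSNorm scales: 3 blocks, dim 8.
    norm_added_q_buffer = {i: jnp.full((8,), float(i)) for i in range(3)}

    sorted_keys = sorted(norm_added_q_buffer.keys())
    stacked_tensor = jnp.stack([norm_added_q_buffer[i] for i in sorted_keys], axis=0)
    print(stacked_tensor.shape)  # (3, 8): one 'scale' row per transformer block

    # One key addresses all blocks at once; 'scale' matches the Flax
    # RMSNorm parameter name, which is why 'kernel' failed validation.
    final_key = ("blocks", "attn2", "norm_added_q", "scale")
    cpu = jax.local_devices(backend="cpu")[0]
    flax_state_dict = {final_key: jax.device_put(stacked_tensor, device=cpu)}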