@@ -215,12 +215,14 @@ def load_base_wan_transformer(
215215 raise FileNotFoundError (f"File { index_file_path } not found for local directory." )
216216 local_files = True
217217 elif hf_download :
218+ # download the index file for sharded models.
218219 index_file_path = hf_hub_download (
219220 pretrained_model_name_or_path ,
220221 subfolder = subfolder ,
221222 filename = filename ,
222223 )
223224 with jax .default_device (device ):
225+ # open the index file.
224226 with open (index_file_path , "r" ) as f :
225227 index_dict = json .load (f )
226228 model_files = set ()
@@ -234,7 +236,7 @@ def load_base_wan_transformer(
234236 ckpt_shard_path = os .path .join (pretrained_model_name_or_path , subfolder , model_file )
235237 else :
236238 ckpt_shard_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = model_file )
237-
239+	 # all shard files referenced by the index have been resolved/downloaded above
238240 max_logging .log (f"Load and port { pretrained_model_name_or_path } { subfolder } on { device } " )
239241
240242 if ckpt_shard_path is not None :
@@ -245,31 +247,29 @@ def load_base_wan_transformer(
245247 flax_state_dict = {}
246248 cpu = jax .local_devices (backend = "cpu" )[0 ]
247249 flattened_dict = flatten_dict (eval_shapes )
250+ # turn all block numbers to strings just for matching weights.
251+ # Later they will be turned back to ints.
248252 random_flax_state_dict = {}
249253 for key in flattened_dict :
250254 string_tuple = tuple ([str (item ) for item in key ])
251255 random_flax_state_dict [string_tuple ] = flattened_dict [key ]
252256 del flattened_dict
253257
254- # 1. Initialize buffer for norm_added_q
255258 norm_added_q_buffer = {}
256259
257260 for pt_key , tensor in tensors .items ():
258- # 2. Robustly Intercept norm_added_q keys
259261 if "norm_added_q" in pt_key and "weight" in pt_key :
260262 parts = pt_key .split ("." )
261263 try :
262- # Find the block index regardless of prefix (blocks.0 vs model.blocks.0)
263264 if "blocks" in parts :
264265 block_idx_loc = parts .index ("blocks" ) + 1
265266 block_idx = int (parts [block_idx_loc ])
266267 tensor = tensor .T
267268 norm_added_q_buffer [block_idx ] = tensor
268269 except Exception :
269- pass # Skip if unparseable
270+ pass
270271 continue
271272
272- # Standard processing
273273 renamed_pt_key = rename_key (pt_key )
274274 if "image_embedder" in renamed_pt_key :
275275 if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
@@ -303,13 +303,10 @@ def load_base_wan_transformer(
303303 flax_key , flax_tensor = get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers )
304304 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
305305
306- # 3. Stack and Insert (Correct Key Name for RMSNorm is 'scale')
307306 if norm_added_q_buffer :
308307 sorted_keys = sorted (norm_added_q_buffer .keys ())
309308 sorted_tensors = [norm_added_q_buffer [i ] for i in sorted_keys ]
310309 stacked_tensor = jnp .stack (sorted_tensors , axis = 0 )
311-
312- # 'scale' is the correct parameter name for Flax/NNX RMSNorm
313310 final_key = ('blocks' , 'attn2' , 'norm_added_q' , 'scale' )
314311
315312 flax_state_dict [final_key ] = jax .device_put (stacked_tensor , device = cpu )
0 commit comments