@@ -177,26 +177,6 @@ def load_causvid_transformer(
   return flax_state_dict
 
 
-def load_wan_transformer(
-    pretrained_model_name_or_path: str,
-    eval_shapes: dict,
-    device: str,
-    hf_download: bool = True,
-    num_layers: int = 40,
-    scan_layers: bool = True,
-    subfolder: str = "",
-):
-
-  if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
-    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
-  elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
-    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
-  else:
-    return load_base_wan_transformer(
-        pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers, subfolder
-    )
-
-
 def load_base_wan_transformer(
     pretrained_model_name_or_path: str,
     eval_shapes: dict,
@@ -206,6 +186,7 @@ def load_base_wan_transformer(
     scan_layers: bool = True,
     subfolder: str = "",
 ):
+  print(f"\n=== DEBUG START: Loading Transformer from {pretrained_model_name_or_path} ===")
   device = jax.local_devices(backend=device)[0]
   filename = "diffusion_pytorch_model.safetensors.index.json"
   local_files = False
@@ -215,14 +196,13 @@ def load_base_wan_transformer(
       raise FileNotFoundError(f"File {index_file_path} not found for local directory.")
     local_files = True
   elif hf_download:
-    # download the index file for sharded models.
     index_file_path = hf_hub_download(
         pretrained_model_name_or_path,
         subfolder=subfolder,
         filename=filename,
     )
+
   with jax.default_device(device):
-    # open the index file.
     with open(index_file_path, "r") as f:
       index_dict = json.load(f)
     model_files = set()
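For context on the sharded-checkpoint handling above: the `.index.json` file's `weight_map` maps each tensor name to the shard file that stores it, so the loader only needs the set of unique shard filenames. The loop that populates `model_files` is elided from this hunk, so the following is a minimal sketch of its assumed shape, not the exact code:

```python
import json

# Assumed shape of the elided loop: the index JSON's "weight_map" maps each
# tensor name to the shard file holding it; deduplicate to get shard files.
with open("diffusion_pytorch_model.safetensors.index.json", "r") as f:
  index_dict = json.load(f)

model_files = set()
for tensor_name, shard_file in index_dict["weight_map"].items():
  model_files.add(shard_file)
```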
@@ -231,37 +211,68 @@ def load_base_wan_transformer(
 
     model_files = list(model_files)
     tensors = {}
+    print(f"=== DEBUG: Loading {len(model_files)} shard files... ===")
+
     for model_file in model_files:
       if local_files:
         ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
       else:
         ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
-      # now get all the filenames for the model that need downloading
+
       max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")
 
       if ckpt_shard_path is not None:
         with safe_open(ckpt_shard_path, framework="pt") as f:
           for k in f.keys():
             tensors[k] = torch2jax(f.get_tensor(k))
+
+    print(f"=== DEBUG: Total tensors loaded: {len(tensors)} ===")
+
     flax_state_dict = {}
     cpu = jax.local_devices(backend="cpu")[0]
     flattened_dict = flatten_dict(eval_shapes)
-    # turn all block numbers to strings just for matching weights.
-    # Later they will be turned back to ints.
     random_flax_state_dict = {}
     for key in flattened_dict:
       string_tuple = tuple([str(item) for item in key])
       random_flax_state_dict[string_tuple] = flattened_dict[key]
     del flattened_dict
+
+    # --- 1. Initialize Buffer ---
     norm_added_q_buffer = {}
+    norm_added_q_debug_count = 0
+
+    print("=== DEBUG: Starting Tensor Loop ===")
     for pt_key, tensor in tensors.items():
       renamed_pt_key = rename_key(pt_key)
+
+      # --- 2. Buffer 'norm_added_q' Logic ---
       if "norm_added_q" in pt_key:
-        parts = pt_key.split(".")
-        block_idx = int(parts[1])
-        tensor = tensor.T
-        norm_added_q_buffer[block_idx] = tensor
-        continue
+        norm_added_q_debug_count += 1
+        # Print the first 3 to confirm we are hitting this block
+        if norm_added_q_debug_count <= 3:
+          print(f"DEBUG: Catching norm_added_q key: {pt_key}")
+
+        try:
+          parts = pt_key.split(".")
+          # Robust parsing: find the part that is a digit
+          block_idx = -1
+          for part in parts:
+            if part.isdigit():
+              block_idx = int(part)
+              break
+
+          if block_idx == -1:
+            print(f"DEBUG ERROR: Could not find block index in {pt_key}")
+            continue
+
+          tensor = tensor.T
+          norm_added_q_buffer[block_idx] = tensor
+        except Exception as e:
+          print(f"DEBUG EXCEPTION parsing {pt_key}: {e}")
+
+        continue  # SKIP rest of loop
+
+      # --- 3. Image Embedder Logic ---
       if "image_embedder" in renamed_pt_key:
         if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
             "net.2" in renamed_pt_key or "net_2" in renamed_pt_key:
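The buffering logic added above replaces the old hard-coded `int(parts[1])` with a scan for the first all-digit component of the dotted key, so keys with extra prefixes still resolve to a block index. A standalone sketch of that parse; the example keys are hypothetical:

```python
def block_index(pt_key: str) -> int:
  """Return the first all-digit dotted component of a key, or -1 if none."""
  for part in pt_key.split("."):
    if part.isdigit():
      return int(part)
  return -1

assert block_index("blocks.3.attn2.norm_added_q.weight") == 3  # hypothetical key
assert block_index("patch_embedding.weight") == -1             # no block index
```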
@@ -281,22 +292,53 @@ def load_base_wan_transformer(
       if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
         renamed_pt_key = renamed_pt_key.replace("weight", "scale")
         renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
+
+      # --- 4. Global Replacements ---
       renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
       renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
       renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
       renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
       renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
+
       if "norm2.layer_norm" not in renamed_pt_key:
-        renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
+        renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
+
       pt_tuple_key = tuple(renamed_pt_key.split("."))
       flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+
+    # --- 5. Final Stack & Insert ---
+    print(f"=== DEBUG: Loop Finished. norm_added_q count: {len(norm_added_q_buffer)} ===")
+
     if norm_added_q_buffer:
-      sorted_tensors = [norm_added_q_buffer[i] for i in sorted(norm_added_q_buffer.keys())]
+      sorted_indices = sorted(norm_added_q_buffer.keys())
+      print(f"DEBUG: Stacking indices from {sorted_indices[0]} to {sorted_indices[-1]}")
+
+      sorted_tensors = [norm_added_q_buffer[i] for i in sorted_indices]
       stacked_tensor = jnp.stack(sorted_tensors, axis=0)
+      print(f"DEBUG: Final Stacked Shape: {stacked_tensor.shape}")
+
+      # INSERT AND VERIFY
       final_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
       flax_state_dict[final_key] = jax.device_put(stacked_tensor, device=cpu)
-
+
+      if final_key in flax_state_dict:
+        print(f"DEBUG: SUCCESS - Key {final_key} exists in dict.")
+      else:
+        print("DEBUG: CRITICAL FAILURE - Key insertion failed.")
+    else:
+      print("DEBUG: FAILURE - Buffer is EMPTY. No norm_added_q found!")
+
+    print("=== DEBUG: Starting Validation ===")
+    # Print what the validator EXPECTS vs what we HAVE
+    expected_keys = set(flatten_dict(eval_shapes).keys())
+    actual_keys = set(flax_state_dict.keys())
+
+    missing = expected_keys - actual_keys
+    print(f"DEBUG: Missing Keys Count: {len(missing)}")
+    if len(missing) > 0 and len(missing) < 10:
+      print(f"DEBUG: Missing Keys: {missing}")
+
     validate_flax_state_dict(eval_shapes, flax_state_dict)
     flax_state_dict = unflatten_dict(flax_state_dict)
     del tensors
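The stack-and-validate step above collects one transposed `norm_added_q` kernel per block, stacks them along a new leading axis (so scanned layers see a `(num_blocks, ...)` parameter), then diffs the expected key set against what was loaded. A toy sketch of the same pattern; the block count follows the function's `num_layers=40` default, and the kernel shape is assumed for illustration:

```python
import jax.numpy as jnp

# Assumed: 40 blocks with a (dim, dim) kernel each; dim=128 is illustrative.
buffer = {i: jnp.ones((128, 128)) for i in range(40)}
stacked = jnp.stack([buffer[i] for i in sorted(buffer)], axis=0)
assert stacked.shape == (40, 128, 128)  # leading axis indexes the block

expected = {("blocks", "attn2", "norm_added_q", "kernel")}
actual = {("blocks", "attn2", "norm_added_q", "kernel")}
missing = expected - actual  # empty set means validation will pass
```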