@@ -253,8 +253,15 @@ def load_base_wan_transformer(
253253 string_tuple = tuple ([str (item ) for item in key ])
254254 random_flax_state_dict [string_tuple ] = flattened_dict [key ]
255255 del flattened_dict
256+ norm_added_q_buffer = {}
256257 for pt_key , tensor in tensors .items ():
257258 renamed_pt_key = rename_key (pt_key )
259+ if "norm_added_q" in pt_key :
260+ parts = pt_key .split ("." )
261+ block_idx = int (parts [1 ])
262+ tensor = tensor .T
263+ norm_added_q_buffer [block_idx ] = tensor
264+ continue
258265 if "norm_added_q" in pt_key :
259266 debug_original = renamed_pt_key
260267 if "image_embedder" in renamed_pt_key :
@@ -276,13 +283,6 @@ def load_base_wan_transformer(
276283 if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
277284 renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
278285 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
279-
280- if "norm_added_q" in renamed_pt_key :
281- renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
282- tensor = tensor .T
283- renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
284-
285-
286286 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
287287 renamed_pt_key = renamed_pt_key .replace (".scale_shift_table" , ".adaln_scale_shift_table" )
288288 renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )
@@ -302,6 +302,11 @@ def load_base_wan_transformer(
302302 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
303303 flax_key , flax_tensor = get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers )
304304 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
305+ if norm_added_q_buffer :
306+ sorted_tensors = [norm_added_q_buffer [i ] for i in sorted (norm_added_q_buffer .keys ())]
307+ stacked_tensor = jnp .stack (sorted_tensors , axis = 0 )
308+ final_key = ('blocks' , 'attn2' , 'norm_added_q' , 'kernel' )
309+ flax_state_dict [final_key ] = jax .device_put (stacked_tensor , device = cpu )
305310
306311 validate_flax_state_dict (eval_shapes , flax_state_dict )
307312 flax_state_dict = unflatten_dict (flax_state_dict )
0 commit comments