@@ -196,6 +196,43 @@ def load_wan_transformer(
196196 pretrained_model_name_or_path , eval_shapes , device , hf_download , num_layers , scan_layers , subfolder
197197 )
198198
199+ def apply_turbo_scaling (params ):
200+ """
201+ Recursively traverses the unflattened state dict to find 'query' and 'key'
202+ layers and scales their kernels by 1/sqrt(2).
203+ """
204+ # Scale factor: 1/sqrt(2) ≈ 0.707
205+ scale_factor = 1.0 / (2 ** 0.5 )
206+
207+ # Counter to verify we actually hit the tensors
208+ scaled_count = 0
209+
210+ def _recursive_walk (d , path_prefix = "" ):
211+ nonlocal scaled_count
212+ # Iterate over a copy of keys to be safe, though we modify values in place
213+ for key , value in d .items ():
214+
215+ # 1. Target Identification: Is this a Query or Key layer?
216+ # We look for dicts named 'query' or 'key' that contain a 'kernel'
217+ if key in ['query' , 'key' ] and isinstance (value , dict ) and 'kernel' in value :
218+ # Apply the scale
219+ original_shape = value ['kernel' ].shape
220+ value ['kernel' ] = value ['kernel' ] * scale_factor
221+ scaled_count += 1
222+ print (f"⚡ Turbo Scaled: { path_prefix } .{ key } .kernel | Shape: { original_shape } " )
223+
224+ # 2. Recursion: If it's a container (like 'blocks' or 'attn1'), dive in.
225+ elif isinstance (value , dict ):
226+ _recursive_walk (value , path_prefix = f"{ path_prefix } .{ key } " if path_prefix else key )
227+
228+ print ("⚡ Starting Recursive Turbo Scaling..." )
229+ _recursive_walk (params )
230+
231+ if scaled_count == 0 :
232+ raise ValueError ("❌ Turbo Scaling Failed: No 'query' or 'key' kernels found! Check dictionary structure." )
233+
234+ print (f"⚡ DONE. Scaled { scaled_count } tensors successfully." )
235+ return params
199236
200237def load_base_wan_transformer (
201238 pretrained_model_name_or_path : str ,
@@ -269,6 +306,7 @@ def load_base_wan_transformer(
269306 flax_state_dict = unflatten_dict (flax_state_dict )
270307 del tensors
271308 jax .clear_caches ()
309+ flax_state_dict = apply_turbo_scaling (flax_state_dict )
272310 return flax_state_dict
273311
274312
0 commit comments