
Commit 2453c1b
Parent: e3e6107

test turbo scaling for head_dim 256

3 files changed: 57 additions & 2 deletions

src/maxdiffusion/models/wan/transformers/transformer_wan.py (17 additions & 1 deletion)
@@ -73,7 +73,14 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array:
     p_t, p_h, p_w = self.patch_size
     ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-    freqs_split = get_frequencies(self.max_seq_len, self.theta, self.attention_head_dim)
+    is_turbo_mode = (self.attention_head_dim == 256)
+
+    # 2. Force the frequency calculation to use 128.
+    # This preserves the original T/H/W split ratios (21/21/22)
+    # instead of stretching them to (42/42/44).
+    calc_dim = 128 if is_turbo_mode else self.attention_head_dim
+
+    freqs_split = get_frequencies(self.max_seq_len, self.theta, calc_dim)
 
     freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1)
     freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1]))
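A note on the (21/21/22) vs (42/42/44) numbers in the comment above: assuming the usual Wan-style rope split, where the height and width axes each get head_dim // 6 rotary pairs and the temporal axis takes the remaining pairs, the arithmetic works out as below. This is an illustrative sketch, not the repo's get_frequencies, and the axis order there may differ.

    def rope_split(head_dim):
        half = head_dim // 2   # one rotation angle covers a pair of dims
        h = w = head_dim // 6  # pairs assigned to height and width
        t = half - h - w       # the temporal axis takes the remainder
        return t, h, w

    print(rope_split(128))  # (22, 21, 21) -- the original split being preserved
    print(rope_split(256))  # (44, 42, 42) -- the stretched split being avoided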
@@ -85,6 +92,15 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array:
     freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1]))
 
     freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1)
+
+    # === TURBO ADAPTER TILING: START ===
+    if is_turbo_mode:
+      # We calculated frequencies for a 128-dim head.
+      # We must duplicate them so the "Second Fused Head" (indices 128-255)
+      # sees the exact same rotation as the "First Head" (indices 0-127).
+      freqs_concat = jnp.concatenate([freqs_concat, freqs_concat], axis=-1)
+    # === TURBO ADAPTER TILING: END ===
+
     freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1))
     return freqs_final
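The tiling above is what lets one 256-dim fused head behave like two independent 128-dim heads under rotary embedding. A standalone sketch of that equivalence (apply_rope here is a generic pairwise rotary implementation standing in for the model's, not code from this repo):

    import jax.numpy as jnp

    def apply_rope(x, freqs):
        # Rotate each (even, odd) feature pair by its angle.
        x0, x1 = x[..., 0::2], x[..., 1::2]
        cos, sin = jnp.cos(freqs), jnp.sin(freqs)
        out = jnp.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], axis=-1)
        return out.reshape(x.shape)

    freqs_128 = jnp.linspace(0.0, 1.0, 64)               # 64 angles rotate 128 dims
    freqs_256 = jnp.concatenate([freqs_128, freqs_128])  # the tiling from the diff

    x = jnp.arange(256, dtype=jnp.float32)
    fused = apply_rope(x, freqs_256)
    halves = jnp.concatenate([apply_rope(x[:128], freqs_128),
                              apply_rope(x[128:], freqs_128)])
    assert jnp.allclose(fused, halves)  # dims 128-255 rotate exactly like dims 0-127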

src/maxdiffusion/models/wan/wan_utils.py (38 additions & 0 deletions)
@@ -196,6 +196,43 @@ def load_wan_transformer(
       pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers, subfolder
   )
 
+def apply_turbo_scaling(params):
+  """
+  Recursively traverses the unflattened state dict to find 'query' and 'key'
+  layers and scales their kernels by 1/sqrt(2).
+  """
+  # Scale factor: 1/sqrt(2) ≈ 0.707
+  scale_factor = 1.0 / (2 ** 0.5)
+
+  # Counter to verify we actually hit the tensors.
+  scaled_count = 0
+
+  def _recursive_walk(d, path_prefix=""):
+    nonlocal scaled_count
+    # We only replace values in place, so iterating d.items() directly is safe.
+    for key, value in d.items():
+
+      # 1. Target identification: is this a query or key layer?
+      # We look for dicts named 'query' or 'key' that contain a 'kernel'.
+      if key in ['query', 'key'] and isinstance(value, dict) and 'kernel' in value:
+        # Apply the scale.
+        original_shape = value['kernel'].shape
+        value['kernel'] = value['kernel'] * scale_factor
+        scaled_count += 1
+        print(f"⚡ Turbo Scaled: {path_prefix}.{key}.kernel | Shape: {original_shape}")
+
+      # 2. Recursion: if it's a container (like 'blocks' or 'attn1'), dive in.
+      elif isinstance(value, dict):
+        _recursive_walk(value, path_prefix=f"{path_prefix}.{key}" if path_prefix else key)
+
+  print("⚡ Starting Recursive Turbo Scaling...")
+  _recursive_walk(params)
+
+  if scaled_count == 0:
+    raise ValueError("❌ Turbo Scaling Failed: No 'query' or 'key' kernels found! Check dictionary structure.")
+
+  print(f"⚡ DONE. Scaled {scaled_count} tensors successfully.")
+  return params
 
 def load_base_wan_transformer(
     pretrained_model_name_or_path: str,
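A toy invocation (hypothetical shapes, not from the commit) showing what the walk matches -- only 'query'/'key' dicts that directly contain a 'kernel' leaf -- and the effect of the scale factor:

    import jax.numpy as jnp

    params = {
        "blocks": {
            "attn1": {
                "query": {"kernel": jnp.ones((8, 8))},
                "key":   {"kernel": jnp.ones((8, 8))},
                "value": {"kernel": jnp.ones((8, 8))},  # not matched, left untouched
            }
        }
    }
    params = apply_turbo_scaling(params)  # query/key kernels become ~0.7071

Since attention logits are bilinear in the query and key projections, scaling both kernels by 1/sqrt(2) scales every logit by exactly 1/2: ((Wq/sqrt(2)) x) . ((Wk/sqrt(2)) y) = (1/2) (Wq x) . (Wk y).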
@@ -269,6 +306,7 @@ def load_base_wan_transformer(
   flax_state_dict = unflatten_dict(flax_state_dict)
   del tensors
   jax.clear_caches()
+  flax_state_dict = apply_turbo_scaling(flax_state_dict)
   return flax_state_dict

src/maxdiffusion/pipelines/wan/wan_pipeline.py (2 additions & 1 deletion)
@@ -87,7 +87,6 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl
   vs.sharding_rules = logical_axis_rules
   return vs
 
-
 # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device.
 def create_sharded_logical_transformer(
     devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder: str = ""
@@ -116,6 +115,8 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
     wan_config["mask_padding_tokens"] = config.mask_padding_tokens
     wan_config["scan_layers"] = config.scan_layers
     wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes
+    wan_config["num_attention_heads"] = 20
+    wan_config["attention_head_dim"] = 256
 
     # 2. eval_shape - will not use flops or create weights on device
     # thus not using HBM memory.
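A quick consistency check on these overrides (assuming the base checkpoint uses 40 heads of dim 128, which the fused-head comments in transformer_wan.py imply): halving the head count while doubling the head dim keeps the inner projection width unchanged, so the pretrained q/k/v kernels fit without reshaping.

    # 40 heads x 128 dims, fused pairwise into 20 heads x 256 dims
    assert 40 * 128 == 20 * 256 == 5120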
