
Commit 77a43cf

experiment 1
1 parent d843dc0 commit 77a43cf

6 files changed

Lines changed: 102 additions & 17 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 8 additions & 2 deletions
@@ -244,7 +244,7 @@ num_eval_samples: 420
 
 warmup_steps_fraction: 0.1
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
-save_optimizer: False
+save_optimizer: True
 
 # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
 # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
@@ -326,4 +326,10 @@ eval_data_dir: ""
 enable_generate_video_for_eval: False # This will increase the used TPU memory.
 eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).
 
-enable_ssim: False
+enable_ssim: False
+
+# Model surgery
+override_model_dims: True
+# If doubling the target_head_dim, then must halve the num_heads
+target_head_dim: 256
+target_num_heads: 20
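
A quick check of the comment above (not part of the commit, and assuming the pretrained Wan 14B layout of 40 heads of 128 dims referenced elsewhere in this diff): halving the head count while doubling the head dim keeps the transformer width unchanged.

# Sanity check: inner_dim = num_heads * head_dim stays constant under the override.
assert 40 * 128 == 20 * 256 == 5120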

src/maxdiffusion/generate_wan.py

Lines changed: 5 additions & 6 deletions
@@ -133,13 +133,12 @@ def run(config, pipeline=None, filename_prefix=""):
   print("seed: ", config.seed)
   model_key = config.model_name
 
-  checkpointer_lib = get_checkpointer(model_key)
-  WanCheckpointer = checkpointer_lib.WanCheckpointer
-
-  checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
-  pipeline, _, _ = checkpoint_loader.load_checkpoint()
-
   if pipeline is None:
+    checkpointer_lib = get_checkpointer(model_key)
+    WanCheckpointer = checkpointer_lib.WanCheckpointer
+
+    checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
+    pipeline, _, _ = checkpoint_loader.load_checkpoint()
     pipeline_lib = get_pipeline(model_key)
     WanPipeline = pipeline_lib.WanPipeline
     pipeline = WanPipeline.from_pretrained(config)
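
A hedged reading of this refactor: when the caller already supplies a pipeline, run() now reuses it instead of unconditionally reloading from the checkpointer. A sketch of the two call patterns, where my_pipeline is a hypothetical pre-built pipeline:

run(config)                         # pipeline is None: load via WanCheckpointer / WanPipeline as before
run(config, pipeline=my_pipeline)   # reuse the supplied pipeline; the checkpoint load is skipped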

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 9 additions & 1 deletion
@@ -225,6 +225,7 @@ def get_1d_rotary_pos_embed(
     ntk_factor=1.0,
     freqs_dtype=jnp.float32,
     use_real: bool = True,
+    original_dim: int = None,
 ):
   """
   Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -235,7 +236,14 @@
     pos = jnp.arange(pos)
 
   theta = theta * ntk_factor
-  freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor
+
+  # If original_dim is provided, we use it as the denominator for the exponent.
+  # For example, if we change the head_dim from 128 to 256, this ensures indices 0-127 generate the EXACT same frequencies they did during pre-training.
+  # Indices 128-255 will simply continue that curve into lower frequencies.
+  scale_dim = original_dim if original_dim is not None else dim
+
+  freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / scale_dim)) / linear_factor
+
   freqs = jnp.outer(pos, freqs)
   if use_real:
     # Flux
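
A small numerical check of the comment in this hunk (a sketch, not part of the commit): anchoring the exponent to original_dim=128 makes the first half of a 256-dim frequency table identical to the pretrained 128-dim table, while the second half continues into lower frequencies.

import jax.numpy as jnp

def rope_freqs(dim, scale_dim, theta=10000.0):
  # Mirrors the formula above, with linear_factor omitted for brevity.
  return 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32)[: (dim // 2)] / scale_dim))

old = rope_freqs(128, 128)   # pre-trained frequencies
new = rope_freqs(256, 128)   # extended table anchored to original_dim=128
assert jnp.allclose(new[:64], old)          # first 64 frequency pairs match exactly
assert bool(jnp.all(new[64:] < old.min()))  # remaining pairs extend the curve into lower frequencies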

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 36 additions & 7 deletions
@@ -39,13 +39,33 @@
 BlockSizes = common_types.BlockSizes
 
 
-def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int):
+def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int, original_attention_head_dim: int):
+
+  # 1. Calculate NEW sub-dimensions (The target shapes)
+  # e.g., for 256: h=84, w=84, t=88
   h_dim = w_dim = 2 * (attention_head_dim // 6)
   t_dim = attention_head_dim - h_dim - w_dim
+  current_dims = [t_dim, h_dim, w_dim]
+
+  # 2. Calculate OLD sub-dimensions (For interpolation reference)
+  # e.g., for 128: h=42, w=42, t=44
+  h_dim_old = w_dim_old = 2 * (original_attention_head_dim // 6)
+  t_dim_old = original_attention_head_dim - h_dim_old - w_dim_old
+  old_dims = [t_dim_old, h_dim_old, w_dim_old]
+
   freqs = []
-  for dim in [t_dim, h_dim, w_dim]:
-    freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float32, use_real=False)
+
+  for dim, old_dim in zip(current_dims, old_dims):
+    freq = get_1d_rotary_pos_embed(
+        dim=dim,  # new size
+        pos=max_seq_len,
+        theta=theta,
+        freqs_dtype=jnp.float32,
+        use_real=False,
+        original_dim=old_dim
+    )
     freqs.append(freq)
+
   freqs = jnp.concatenate(freqs, axis=1)
   t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6)
   hw_size = attention_head_dim // 6
@@ -61,8 +81,16 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int):
 
 class WanRotaryPosEmbed(nnx.Module):
 
-  def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0):
+  def __init__(
+      self,
+      attention_head_dim: int,
+      original_attention_head_dim: int,
+      patch_size: Tuple[int, int, int],
+      max_seq_len: int,
+      theta: float = 10000.0
+  ):
     self.attention_head_dim = attention_head_dim
+    self.original_attention_head_dim = original_attention_head_dim
     self.patch_size = patch_size
     self.max_seq_len = max_seq_len
     self.theta = theta
@@ -72,7 +100,7 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array:
     p_t, p_h, p_w = self.patch_size
     ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-    freqs_split = get_frequencies(self.max_seq_len, self.theta, self.attention_head_dim)
+    freqs_split = get_frequencies(self.max_seq_len, self.theta, self.attention_head_dim, self.original_attention_head_dim)
 
     freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1)
     freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1]))
@@ -378,6 +406,7 @@ class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
   def __init__(
       self,
       rngs: nnx.Rngs,
+      target_head_dim: int,
      model_type="t2v",
      patch_size: Tuple[int] = (1, 2, 2),
      num_attention_heads: int = 40,
@@ -408,13 +437,13 @@ def __init__(
      names_which_can_be_offloaded: list = [],
      scan_layers: bool = True,
   ):
-    inner_dim = num_attention_heads * attention_head_dim
+    inner_dim = num_attention_heads * target_head_dim
     out_channels = out_channels or in_channels
     self.num_layers = num_layers
     self.scan_layers = scan_layers
 
     # 1. Patch & position embedding
-    self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+    self.rope = WanRotaryPosEmbed(target_head_dim, attention_head_dim, patch_size, rope_max_seq_len)
     self.patch_embedding = nnx.Conv(
        in_channels,
        inner_dim,
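
A quick check of the sub-dimension comments in get_frequencies (a sketch, not part of the commit): the t/h/w split follows directly from the two integer-division lines above.

def split_dims(head_dim):
  # Mirrors get_frequencies: h and w each get 2 * (head_dim // 6); t takes the remainder.
  h = w = 2 * (head_dim // 6)
  t = head_dim - h - w
  return t, h, w

assert split_dims(128) == (44, 42, 42)  # original_attention_head_dim
assert split_dims(256) == (88, 84, 84)  # target_head_dim from the config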

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 43 additions & 0 deletions
@@ -22,6 +22,7 @@
 import flax.linen as nn
 from flax import nnx
 from flax.linen import partitioning as nn_partitioning
+from flax.traverse_util import flatten_dict, unflatten_dict
 from ...pyconfig import HyperParameters
 from ... import max_logging
 from ... import max_utils
@@ -86,6 +87,42 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState:
   vs.sharding_rules = logical_axis_rules
   return vs
 
+def perform_wan_scaling_surgery(params, target_head_dim, source_head_dim):
+  """
+  scales Q and K weights to preserve attention entropy when
+  changing head dimensions.
+
+  Formula: correction_factor = (target_dim / source_dim)^0.25
+  """
+
+  if target_head_dim == source_head_dim:
+    print("Target and Source head dims are identical. Skipping surgery.")
+    return params
+
+  # Calculate the factor
+  # Example: (256 / 128)^0.25 = 2^0.25 ≈ 1.1892
+  ratio = target_head_dim / source_head_dim
+  correction_factor = ratio ** 0.25
+
+  flat_params = flatten_dict(params, sep='/')
+  new_flat_params = {}
+  modified_count = 0
+
+  for key, tensor in flat_params.items():
+    # Key format example: 'transformer_blocks/0/attn1/query/kernel'
+    # Identify Query and Key kernels.
+    if ('query' in key or 'key' in key) and 'kernel' in key:
+      # Ensure we are targeting attention layers, not other projections
+      if 'attn' in key:
+        new_flat_params[key] = tensor * correction_factor
+        modified_count += 1
+      else:
+        new_flat_params[key] = tensor
+    else:
+      new_flat_params[key] = tensor
+
+  print(f"Surgery complete. Scaled {modified_count} tensors by {correction_factor:.4f}")
+  return unflatten_dict(new_flat_params, sep='/')
 
 # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device.
 def create_sharded_logical_transformer(
@@ -113,6 +150,10 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_config["flash_min_seq_length"] = config.flash_min_seq_length
   wan_config["dropout"] = config.dropout
   wan_config["scan_layers"] = config.scan_layers
+  wan_config["target_head_dim"] = wan_config["attention_head_dim"]
+  if config.override_model_dims:
+    wan_config["target_head_dim"] = config.target_head_dim
+    wan_config["num_attention_heads"] = config.target_num_heads
 
   # 2. eval_shape - will not use flops or create weights on device
   # thus not using HBM memory.
@@ -144,6 +185,8 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
       scan_layers=config.scan_layers,
       subfolder=subfolder,
   )
+  if config.override_model_dims:
+    params = perform_wan_scaling_surgery(params, config.target_head_dim, wan_config["attention_head_dim"])
 
   params = jax.tree_util.tree_map_with_path(
       lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params
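
Why the fourth root in perform_wan_scaling_surgery (a hedged reading; the commit itself only states the formula): attention divides logits by sqrt(head_dim), so if the widened dimensions initially contribute little to q·k, growing head_dim from 128 to 256 would shrink the logits by sqrt(1/2). Scaling both the query and key kernels by (256/128)^0.25 multiplies q·k by (256/128)^0.5, which cancels that shrinkage. A minimal numeric check:

import math

source_head_dim, target_head_dim = 128, 256
correction = (target_head_dim / source_head_dim) ** 0.25  # ~1.1892, as in the docstring

# Logit scale before widening vs. after widening with Q and K each scaled by `correction`.
before = 1.0 / math.sqrt(source_head_dim)
after = correction ** 2 / math.sqrt(target_head_dim)
assert math.isclose(before, after)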

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 import os
 import datetime
 import functools
-from pprint import pprint
+import pprint
 import numpy as np
 import threading
 from concurrent.futures import ThreadPoolExecutor
