Skip to content

Commit 12a0fd2

Browse files
committed
Support for WAN 2.2
1 parent c9229c3 commit 12a0fd2

7 files changed

Lines changed: 1402 additions & 49 deletions

File tree

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from abc import ABC
18+
import json
19+
20+
import jax
21+
import numpy as np
22+
from typing import Optional, Tuple
23+
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
24+
from ..pipelines.wan.wan_pipeline2_2 import WanPipeline
25+
from .. import max_logging, max_utils
26+
import orbax.checkpoint as ocp
27+
28+
WAN_CHECKPOINT = "WAN_CHECKPOINT"
29+
30+
31+
class WanCheckpointer(ABC):
  """Saves and restores Orbax checkpoints for the WAN 2.2 pipeline.

  WAN 2.2 uses two transformers (a low-noise and a high-noise expert). Their
  train states are stored under separate items of a single composite
  checkpoint, together with a JSON copy of the transformer config under
  ``wan_config``.
  """

  def __init__(self, config, checkpoint_type):
    self.config = config
    self.checkpoint_type = checkpoint_type
    # Populated externally / on load; holds the restored optimizer state.
    self.opt_state = None

    self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, model, config, learning_rate):
    """Builds the optimizer and its learning-rate schedule.

    NOTE(review): ``model`` is currently unused; kept for interface
    compatibility with callers.

    Returns:
      (tx, learning_rate_scheduler) where ``tx`` is the optax transformation.
    """
    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return tx, learning_rate_scheduler

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    """Restores the composite WAN 2.2 checkpoint at ``step``.

    Args:
      step: Checkpoint step to restore. If ``None``, the latest available
        step is used.

    Returns:
      ``(restored_checkpoint, step)``, or ``(None, None)`` when no
      checkpoint exists.
    """
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
      if step is None:
        max_logging.log("No WAN checkpoint found.")
        return None, None
    max_logging.log(f"Loading WAN checkpoint from step {step}")
    metadatas = self.checkpoint_manager.item_metadata(step)

    restore_args = {}

    # Restore both transformer states as host numpy arrays; device placement /
    # sharding is applied later when the train states are rebuilt.
    low_state_metadata = metadatas.low_noise_transformer_state
    abstract_tree_structure_low_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_state_metadata)
    low_state_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_low_state,
        )
    )
    restore_args["low_noise_transformer_state"] = low_state_restore

    high_state_metadata = metadatas.high_noise_transformer_state
    abstract_tree_structure_high_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_state_metadata)
    high_state_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_high_state,
        )
    )
    restore_args["high_noise_transformer_state"] = high_state_restore

    restore_args["wan_config"] = ocp.args.JsonRestore()

    max_logging.log("Restoring WAN 2.2 checkpoint")
    restored_checkpoint = self.checkpoint_manager.restore(
        step=step,
        args=ocp.args.Composite(**restore_args),
    )
    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
    # FIX: the composite has no `wan_state` item (that was the WAN 2.1 layout);
    # accessing `restored_checkpoint.wan_state` raised after every restore.
    # Log the per-transformer items that actually exist instead.
    for state_key in ("low_noise_transformer_state", "high_noise_transformer_state"):
      if state_key in restored_checkpoint.keys():
        max_logging.log(f"restored checkpoint {state_key} {restored_checkpoint[state_key].keys()}")
        max_logging.log(f"optimizer found in {state_key}: {'opt_state' in restored_checkpoint[state_key].keys()}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return restored_checkpoint, step

  def load_diffusers_checkpoint(self):
    """Loads the default pipeline from pretrained diffusers-format weights."""
    pipeline = WanPipeline.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]:
    """Loads the pipeline from Orbax, falling back to pretrained weights.

    Returns:
      ``(pipeline, opt_state, step)`` — ``opt_state`` and ``step`` are
      ``None`` when no Orbax checkpoint was found.
    """
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
    opt_state = None
    if restored_checkpoint:
      max_logging.log("Loading WAN pipeline from checkpoint")
      pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
      # FIX: there is no `wan_state` item in the WAN 2.2 composite checkpoint;
      # the optimizer state (when saved) lives inside the per-transformer
      # state trees. NOTE(review): assumes both experts share one opt_state,
      # with the low-noise state taking precedence — confirm against trainer.
      for state_key in ("low_noise_transformer_state", "high_noise_transformer_state"):
        if state_key in restored_checkpoint.keys() and "opt_state" in restored_checkpoint[state_key].keys():
          opt_state = restored_checkpoint[state_key]["opt_state"]
          break
    else:
      max_logging.log("No checkpoint found, loading default pipeline.")
      pipeline = self.load_diffusers_checkpoint()

    return pipeline, opt_state, step

  def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      # Round-trip through JSON so Orbax receives plain serializable types.
      return json.loads(model_or_config.to_json_string())

    max_logging.log(f"Saving checkpoint for step {train_step}")
    items = {
        # Both experts share an architecture; the low-noise config stands in
        # for the pipeline's transformer config.
        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
    }

    if "low_noise_transformer" in train_states:
      items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])

    if "high_noise_transformer" in train_states:
      items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])

    # Only save when at least one transformer state is present; "wan_config"
    # alone is not a usable checkpoint.
    if len(items) > 1:
      self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
      max_logging.log(f"Checkpoint for step {train_step} saved.")
    else:
      # FIX: previously logged "saved." even when nothing was written.
      max_logging.log(f"No transformer states provided; skipped checkpoint save for step {train_step}.")

  def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
    """Saves the training state and model configurations (legacy WAN 2.1 layout).

    NOTE(review): references ``pipeline.transformer``, which the WAN 2.2
    pipeline may not expose (it has low/high-noise transformers) — confirm
    before using this path.
    """

    def config_to_json(model_or_config):
      """
      only save the config that is needed and can be serialized to JSON.
      """
      if not hasattr(model_or_config, "config"):
        return None
      source_config = dict(model_or_config.config)

      # 1. configs that can be serialized to JSON
      SAFE_KEYS = [
          "_class_name",
          "_diffusers_version",
          "model_type",
          "patch_size",
          "num_attention_heads",
          "attention_head_dim",
          "in_channels",
          "out_channels",
          "text_dim",
          "freq_dim",
          "ffn_dim",
          "num_layers",
          "cross_attn_norm",
          "qk_norm",
          "eps",
          "image_dim",
          "added_kv_proj_dim",
          "rope_max_seq_len",
          "pos_embed_seq_len",
          "flash_min_seq_length",
          "flash_block_sizes",
          "attention",
          "_use_default_values",
      ]

      # 2. save the config that are in the SAFE_KEYS list
      clean_config = {}
      for key in SAFE_KEYS:
        if key in source_config:
          clean_config[key] = source_config[key]

      # 3. deal with special data type and precision
      if "dtype" in source_config and hasattr(source_config["dtype"], "name"):
        clean_config["dtype"] = source_config["dtype"].name  # e.g 'bfloat16'

      if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"):
        clean_config["weights_dtype"] = source_config["weights_dtype"].name

      # FIX: was `isinstance(source_config["precision"])` — isinstance with a
      # single argument raises TypeError. Use the same hasattr(..., "name")
      # pattern as the dtype fields above.
      if "precision" in source_config and hasattr(source_config["precision"], "name"):
        clean_config["precision"] = source_config["precision"].name  # e.g. 'HIGHEST'

      return clean_config

    items_to_save = {
        "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
    }

    items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)

    # Create CompositeArgs for Orbax
    save_args = ocp.args.Composite(**items_to_save)

    # Save the checkpoint
    self.checkpoint_manager.save(train_step, args=save_args)
    max_logging.log(f"Checkpoint for step {train_step} saved.")

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ save_config_to_gcs: False
2828
log_period: 100
2929

3030
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
31+
model_name: wan2.1
3132

3233
# Overrides the transformer from pretrained_model_name_or_path
3334
wan_transformer_pretrained_model_name_or_path: ''

0 commit comments

Comments
 (0)