Skip to content

Commit 6907953

Browse files
authored
Added support for WAN 2.2 model (#281)
* Changes for WAN 2.2 * changes return type of checkpoint_loader to tuple * opt_state=None added * added model_name in config file * double noise computation fixed * support for wan2.1 in run_inference added * Support for WAN 2.2 added * Removed extra files * Updated README and generate_wan.py * Added tensorboard logging for inference metrics * Fixed duplicate pipeline loading * Merge conflicts * ruff errors * Changes to Wan trainer for compatibility with checkpointer * flash block size changed for testing * Revert "flash block size changed for testing" This reverts commit d6cdb1e. * Raise error for unsupported model training * Explicitly instantiate WanPipeline and WanCheckpointer subclasses * ruff errors * Added commit_id to tensorboard logging * Commit hash logging * Added enable_jax_named_scopes param for wan 2.2 * Wanpipeline and WanCheckpointer split into files * ruff errors * pytest errors fixed * pytest errors fixed * pytest errors fixed * wan_checkpointer_test.py fixed
1 parent f1ff3cc commit 6907953

16 files changed

Lines changed: 850 additions & 1257 deletions

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
[![Unit Tests](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)
1818

1919
# What's new?
20+
- **`2025/11/11`**: Wan2.2 txt2vid generation is now supported.
2021
- **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported.
2122
- **`2025/10/14`**: NVIDIA DGX Spark Flux support.
2223
- **`2025/8/14`**: LTX-Video img2vid generation is now supported.

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 24 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,33 @@
1414
limitations under the License.
1515
"""
1616

17-
from abc import ABC
18-
import json
19-
20-
import jax
21-
import numpy as np
17+
from abc import ABC, abstractmethod
2218
from typing import Optional, Tuple
2319
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
24-
from ..pipelines.wan.wan_pipeline import WanPipeline
20+
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
21+
from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
2522
from .. import max_logging, max_utils
2623
import orbax.checkpoint as ocp
27-
from etils import epath
24+
2825

2926
WAN_CHECKPOINT = "WAN_CHECKPOINT"
3027

3128

3229
class WanCheckpointer(ABC):
3330

  def __init__(self, config, checkpoint_type: str = WAN_CHECKPOINT):
    """Initializes the checkpointer and its Orbax checkpoint manager.

    Args:
      config: Run configuration; must expose `checkpoint_dir` and
        `dataset_type` (both are forwarded to the checkpoint manager).
      checkpoint_type: Checkpoint flavor tag; defaults to WAN_CHECKPOINT.
    """
    self.config = config
    self.checkpoint_type = checkpoint_type
    # Populated later from a restored checkpoint's `opt_state`, if present.
    self.opt_state = None

    self.checkpoint_manager: ocp.CheckpointManager = (
        create_orbax_checkpoint_manager(
            self.config.checkpoint_dir,
            enable_checkpointing=True,
            # NOTE(review): interval of 1 allows a save at every step — confirm intended.
            save_interval_steps=1,
            checkpoint_type=checkpoint_type,
            dataset_type=config.dataset_type,
        )
    )

4746
def _create_optimizer(self, model, config, learning_rate):
@@ -51,76 +50,23 @@ def _create_optimizer(self, model, config, learning_rate):
5150
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
5251
return tx, learning_rate_scheduler
5352

  @abstractmethod
  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    """Restores a WAN checkpoint at `step` (or the latest step when None).

    Concrete subclasses implement the Orbax restore.

    Returns:
      A `(restored_checkpoint, step)` tuple, or `(None, None)` when no
      checkpoint exists.
    """
    raise NotImplementedError
8656

  @abstractmethod
  def load_diffusers_checkpoint(self):
    """Builds a fresh pipeline from pretrained (diffusers-style) weights.

    Subclasses use this as the fallback when no Orbax checkpoint is found.
    """
    raise NotImplementedError
11760

  @abstractmethod
  def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2], Optional[dict], Optional[int]]:
    """Loads a pipeline plus optional optimizer state at `step`.

    Returns:
      `(pipeline, opt_state, step)`; `opt_state` and `step` are None when
      no checkpoint (or no saved optimizer state) is available.
    """
    raise NotImplementedError
12164

  @abstractmethod
  def save_checkpoint(self, train_step, pipeline, train_states: dict):
    """Persists `train_states` and the model configuration for `train_step`."""
    raise NotImplementedError
12268

123-
def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
69+
def save_checkpoint_orig(self, train_step, pipeline, train_states: dict):
12470
"""Saves the training state and model configurations."""
12571

12672
def config_to_json(model_or_config):
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import json
18+
import jax
19+
import numpy as np
20+
from typing import Optional, Tuple
21+
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
22+
from .. import max_logging
23+
import orbax.checkpoint as ocp
24+
from etils import epath
25+
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
26+
27+
class WanCheckpointer2_1(WanCheckpointer):
  """Checkpointer for the WAN 2.1 pipeline.

  Implements the abstract WanCheckpointer interface using a single
  transformer state tree saved under the `wan_state` item.
  """

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    """Restores the Orbax checkpoint at `step` (latest step when None).

    Returns:
      `(restored_checkpoint, step)`, or `(None, None)` if no checkpoint
      exists in `config.checkpoint_dir`.
    """
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
      # Still None after querying the manager => nothing was ever saved.
      if step is None:
        max_logging.log("No WAN checkpoint found.")
        return None, None
    max_logging.log(f"Loading WAN checkpoint from step {step}")
    metadatas = self.checkpoint_manager.item_metadata(step)
    transformer_metadata = metadatas.wan_state
    # Build an abstract (shape/dtype only) tree mirroring the saved state so
    # Orbax knows the structure to restore into.
    abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
    params_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            # Restore every leaf as a host numpy array (not a device array).
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_params,
        )
    )

    max_logging.log("Restoring WAN checkpoint")
    restored_checkpoint = self.checkpoint_manager.restore(
        directory=epath.Path(self.config.checkpoint_dir),
        step=step,
        args=ocp.args.Composite(
            wan_state=params_restore,
            wan_config=ocp.args.JsonRestore(),
        ),
    )
    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
    max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
    max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return restored_checkpoint, step

  def load_diffusers_checkpoint(self):
    """Builds a WanPipeline2_1 from pretrained weights (no Orbax state)."""
    pipeline = WanPipeline2_1.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_1, Optional[dict], Optional[int]]:
    """Loads a pipeline and optional optimizer state.

    Falls back to pretrained weights when no Orbax checkpoint exists.

    Returns:
      `(pipeline, opt_state, step)`; `opt_state`/`step` are None when no
      checkpoint (or no saved optimizer state) is available.
    """
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
    opt_state = None
    if restored_checkpoint:
      max_logging.log("Loading WAN pipeline from checkpoint")
      pipeline = WanPipeline2_1.from_checkpoint(self.config, restored_checkpoint)
      # Optimizer state is optional: inference-only checkpoints omit it.
      if "opt_state" in restored_checkpoint.wan_state.keys():
        opt_state = restored_checkpoint.wan_state["opt_state"]
    else:
      max_logging.log("No checkpoint found, loading default pipeline.")
      pipeline = self.load_diffusers_checkpoint()

    return pipeline, opt_state, step

  def save_checkpoint(self, train_step, pipeline: WanPipeline2_1, train_states: dict):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      # Serialize the transformer's config through its JSON representation.
      return json.loads(model_or_config.to_json_string())

    max_logging.log(f"Saving checkpoint for step {train_step}")
    items = {
        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
    }

    items["wan_state"] = ocp.args.PyTreeSave(train_states)

    # Save the checkpoint
    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
    max_logging.log(f"Checkpoint for step {train_step} saved.")

0 commit comments

Comments
 (0)