Skip to content

Commit 2535742

Browse files
committed
WAN Img2Vid implementation base commit
1 parent 5cbf844 commit 2535742

17 files changed

Lines changed: 2070 additions & 74 deletions

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
2020
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
2121
from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
22+
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
23+
from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
2224
from .. import max_logging, max_utils
2325
import orbax.checkpoint as ocp
2426

@@ -59,7 +61,7 @@ def load_diffusers_checkpoint(self):
5961
raise NotImplementedError
6062

6163
@abstractmethod
62-
def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2], Optional[dict], Optional[int]]:
64+
def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2 | WanPipelineI2V_2_1 | WanPipelineI2V_2_2], Optional[dict], Optional[int]]:
6365
raise NotImplementedError
6466

6567
@abstractmethod
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import json
18+
import jax
19+
import numpy as np
20+
from typing import Optional, Tuple
21+
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
22+
from .. import max_logging
23+
import orbax.checkpoint as ocp
24+
from etils import epath
25+
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
26+
27+
class WanCheckpointerI2V_2_1(WanCheckpointer):
  """Orbax-backed checkpointer for the WAN 2.1 image-to-video pipeline.

  The methods read ``self.config``, ``self.checkpoint_manager`` and
  ``self.opt_state``; none of them is assigned in this class.
  NOTE(review): presumably they are initialised by ``WanCheckpointer`` --
  confirm against the base class.
  """

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    """Restore the ``wan_state`` and ``wan_config`` items of one checkpoint.

    Args:
      step: Checkpoint step to restore. ``None`` selects the latest step
        reported by the checkpoint manager.

    Returns:
      ``(restored_checkpoint, step)``, or ``(None, None)`` when no step was
      requested and the manager has no latest step.
    """
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
    if step is None:
      max_logging.log("No WAN checkpoint found.")
      return None, None

    max_logging.log(f"Loading WAN checkpoint from step {step}")

    def as_numpy(_):
      # Every leaf of the saved state is requested back as a numpy array.
      return ocp.RestoreArgs(restore_type=np.ndarray)

    state_metadata = self.checkpoint_manager.item_metadata(step).wan_state
    state_structure = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state_metadata)
    state_restore = ocp.args.PyTreeRestore(restore_args=jax.tree.map(as_numpy, state_structure))

    max_logging.log("Restoring WAN checkpoint")
    checkpoint_dir = epath.Path(self.config.checkpoint_dir)
    restore_request = ocp.args.Composite(
        wan_state=state_restore,
        wan_config=ocp.args.JsonRestore(),
    )
    checkpoint = self.checkpoint_manager.restore(
        directory=checkpoint_dir,
        step=step,
        args=restore_request,
    )

    max_logging.log(f"restored checkpoint {checkpoint.keys()}")
    state_keys = checkpoint.wan_state.keys()
    max_logging.log(f"restored checkpoint wan_state {state_keys}")
    max_logging.log(f"optimizer found in checkpoint {'opt_state' in state_keys}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return checkpoint, step

  def load_diffusers_checkpoint(self):
    """Build the pipeline via ``WanPipelineI2V_2_1.from_pretrained(self.config)``."""
    return WanPipelineI2V_2_1.from_pretrained(self.config)

  def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_1, Optional[dict], Optional[int]]:
    """Load the pipeline from an Orbax checkpoint, else from pretrained weights.

    Args:
      step: Checkpoint step to load; ``None`` means the latest one.

    Returns:
      ``(pipeline, opt_state, step)``. ``opt_state`` is the ``"opt_state"``
      entry of the restored ``wan_state`` when present, otherwise ``None``.
    """
    checkpoint, step = self.load_wan_configs_from_orbax(step)
    if not checkpoint:
      max_logging.log("No checkpoint found, loading default pipeline.")
      return self.load_diffusers_checkpoint(), None, step

    max_logging.log("Loading WAN pipeline from checkpoint")
    pipeline = WanPipelineI2V_2_1.from_checkpoint(self.config, checkpoint)
    restored_opt_state = None
    if "opt_state" in checkpoint.wan_state.keys():
      restored_opt_state = checkpoint.wan_state["opt_state"]
    return pipeline, restored_opt_state, step

  def save_checkpoint(self, train_step, pipeline: WanPipelineI2V_2_1, train_states: dict):
    """Save the training state and the transformer configuration.

    Args:
      train_step: Step number the checkpoint is saved under.
      pipeline: Pipeline whose ``transformer`` config is stored as ``wan_config``.
      train_states: PyTree stored as ``wan_state``.
    """
    max_logging.log(f"Saving checkpoint for step {train_step}")
    transformer_config = json.loads(pipeline.transformer.to_json_string())
    checkpoint_items = ocp.args.Composite(
        wan_config=ocp.args.JsonSave(transformer_config),
        wan_state=ocp.args.PyTreeSave(train_states),
    )

    # Save the checkpoint
    self.checkpoint_manager.save(train_step, args=checkpoint_items)
    max_logging.log(f"Checkpoint for step {train_step} saved.")
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import json
18+
import jax
19+
import numpy as np
20+
from typing import Optional, Tuple
21+
from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
22+
from .. import max_logging
23+
import orbax.checkpoint as ocp
24+
from etils import epath
25+
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
26+
27+
class WanCheckpointerI2V_2_2(WanCheckpointer):
  """Orbax-backed checkpointer for the WAN 2.2 image-to-video pipeline.

  Checkpoints hold two transformer states (``low_noise_transformer_state`` and
  ``high_noise_transformer_state``) plus one ``wan_config`` JSON item.

  The methods read ``self.config``, ``self.checkpoint_manager`` and
  ``self.opt_state``; none of them is assigned in this class.
  NOTE(review): presumably they are initialised by ``WanCheckpointer`` --
  confirm against the base class.
  """

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    """Restore both transformer states and ``wan_config`` of one checkpoint.

    Args:
      step: Checkpoint step to restore. ``None`` selects the latest step
        reported by the checkpoint manager.

    Returns:
      ``(restored_checkpoint, step)``, or ``(None, None)`` when no step was
      requested and the manager has no latest step.
    """
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
    if step is None:
      max_logging.log("No WAN checkpoint found.")
      return None, None

    max_logging.log(f"Loading WAN checkpoint from step {step}")

    def numpy_restore_for(item_metadata):
      # Mirror the saved tree structure and request every leaf as a numpy array.
      structure = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, item_metadata)
      return ocp.args.PyTreeRestore(
          restore_args=jax.tree.map(
              lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
              structure,
          )
      )

    metadatas = self.checkpoint_manager.item_metadata(step)
    # Handle low_noise_transformer
    low_noise_restore = numpy_restore_for(metadatas.low_noise_transformer_state)
    # Handle high_noise_transformer
    high_noise_restore = numpy_restore_for(metadatas.high_noise_transformer_state)

    max_logging.log("Restoring WAN 2.2 checkpoint")
    checkpoint_dir = epath.Path(self.config.checkpoint_dir)
    restore_request = ocp.args.Composite(
        low_noise_transformer_state=low_noise_restore,
        high_noise_transformer_state=high_noise_restore,
        wan_config=ocp.args.JsonRestore(),
    )
    checkpoint = self.checkpoint_manager.restore(
        directory=checkpoint_dir,
        step=step,
        args=restore_request,
    )

    max_logging.log(f"restored checkpoint {checkpoint.keys()}")
    max_logging.log(f"restored checkpoint low_noise_transformer_state {checkpoint.low_noise_transformer_state.keys()}")
    max_logging.log(f"restored checkpoint high_noise_transformer_state {checkpoint.high_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in checkpoint.low_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in checkpoint.high_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return checkpoint, step

  def load_diffusers_checkpoint(self):
    """Build the pipeline via ``WanPipelineI2V_2_2.from_pretrained(self.config)``."""
    return WanPipelineI2V_2_2.from_pretrained(self.config)

  def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_2, Optional[dict], Optional[int]]:
    """Load the pipeline from an Orbax checkpoint, else from pretrained weights.

    Args:
      step: Checkpoint step to load; ``None`` means the latest one.

    Returns:
      ``(pipeline, opt_state, step)``. ``opt_state`` is taken from the
      low-noise transformer state when it has an ``"opt_state"`` entry,
      otherwise from the high-noise transformer state, otherwise ``None``.
    """
    checkpoint, step = self.load_wan_configs_from_orbax(step)
    if not checkpoint:
      max_logging.log("No checkpoint found, loading default pipeline.")
      return self.load_diffusers_checkpoint(), None, step

    max_logging.log("Loading WAN pipeline from checkpoint")
    pipeline = WanPipelineI2V_2_2.from_checkpoint(self.config, checkpoint)

    # Check for optimizer state in either transformer; the low-noise one wins.
    restored_opt_state = None
    if "opt_state" in checkpoint.low_noise_transformer_state.keys():
      restored_opt_state = checkpoint.low_noise_transformer_state["opt_state"]
    elif "opt_state" in checkpoint.high_noise_transformer_state.keys():
      restored_opt_state = checkpoint.high_noise_transformer_state["opt_state"]
    return pipeline, restored_opt_state, step

  def save_checkpoint(self, train_step, pipeline: WanPipelineI2V_2_2, train_states: dict):
    """Save both transformer training states and the model configuration.

    Args:
      train_step: Step number the checkpoint is saved under.
      pipeline: Pipeline whose ``low_noise_transformer`` config is stored as
        ``wan_config`` (the high-noise transformer's config is not saved).
      train_states: Mapping with ``"low_noise_transformer"`` and
        ``"high_noise_transformer"`` entries, each stored as its own PyTree item.
    """
    max_logging.log(f"Saving checkpoint for step {train_step}")
    transformer_config = json.loads(pipeline.low_noise_transformer.to_json_string())
    checkpoint_items = ocp.args.Composite(
        wan_config=ocp.args.JsonSave(transformer_config),
        low_noise_transformer_state=ocp.args.PyTreeSave(train_states["low_noise_transformer"]),
        high_noise_transformer_state=ocp.args.PyTreeSave(train_states["high_noise_transformer"]),
    )

    # Save the checkpoint
    self.checkpoint_manager.save(train_step, args=checkpoint_items)
    max_logging.log(f"Checkpoint for step {train_step} saved.")

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ log_period: 100
2929

3030
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
3131
model_name: wan2.1
32+
model_type: 'T2V'
3233

3334
# Overrides the transformer from pretrained_model_name_or_path
3435
wan_transformer_pretrained_model_name_or_path: ''

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ log_period: 100
2929

3030
pretrained_model_name_or_path: 'Wan-AI/Wan2.2-T2V-A14B-Diffusers'
3131
model_name: wan2.2
32+
model_type: 'T2V'
3233

3334
# Overrides the transformer from pretrained_model_name_or_path
3435
wan_transformer_pretrained_model_name_or_path: ''

0 commit comments

Comments
 (0)