Skip to content

Commit 7d43d11

Browse files
committed
wan_checkpointer2_2.py modified
1 parent 12a0fd2 commit 7d43d11

2 files changed

Lines changed: 144 additions & 32 deletions

File tree

src/maxdiffusion/checkpointing/wan_checkpointer2_2.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ..pipelines.wan.wan_pipeline2_2 import WanPipeline
2525
from .. import max_logging, max_utils
2626
import orbax.checkpoint as ocp
27+
from etils import epath
2728

2829
WAN_CHECKPOINT = "WAN_CHECKPOINT"
2930

@@ -59,39 +60,40 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
5960
return None, None
6061
max_logging.log(f"Loading WAN checkpoint from step {step}")
6162
metadatas = self.checkpoint_manager.item_metadata(step)
62-
63-
restore_args = {}
64-
65-
low_state_metadata = metadatas.low_noise_transformer_state
66-
abstract_tree_structure_low_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_state_metadata)
67-
low_state_restore = ocp.args.PyTreeRestore(
63+
64+
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
65+
abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata)
66+
low_params_restore = ocp.args.PyTreeRestore(
6867
restore_args=jax.tree.map(
6968
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
70-
abstract_tree_structure_low_state,
69+
abstract_tree_structure_low_params,
7170
)
7271
)
73-
restore_args["low_noise_transformer_state"] = low_state_restore
74-
75-
high_state_metadata = metadatas.high_noise_transformer_state
76-
abstract_tree_structure_high_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_state_metadata)
77-
high_state_restore = ocp.args.PyTreeRestore(
72+
73+
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
74+
abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata)
75+
high_params_restore = ocp.args.PyTreeRestore(
7876
restore_args=jax.tree.map(
7977
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
80-
abstract_tree_structure_high_state,
78+
abstract_tree_structure_high_params,
8179
)
8280
)
83-
restore_args["high_noise_transformer_state"] = high_state_restore
8481

85-
restore_args["wan_config"] = ocp.args.JsonRestore()
86-
87-
max_logging.log("Restoring WAN 2.2 checkpoint")
82+
max_logging.log("Restoring WAN checkpoint")
8883
restored_checkpoint = self.checkpoint_manager.restore(
84+
directory=epath.Path(self.config.checkpoint_dir),
8985
step=step,
90-
args=ocp.args.Composite(**restore_args),
86+
args=ocp.args.Composite(
87+
low_noise_transformer_state=low_params_restore,
88+
high_noise_transformer_state=high_params_restore,
89+
wan_config=ocp.args.JsonRestore(),
90+
),
9191
)
9292
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
93-
max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
94-
max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}")
93+
max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}")
94+
max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}")
95+
max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}")
96+
max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}")
9597
max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
9698
return restored_checkpoint, step
9799

@@ -105,8 +107,11 @@ def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optio
105107
if restored_checkpoint:
106108
max_logging.log("Loading WAN pipeline from checkpoint")
107109
pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
108-
if "opt_state" in restored_checkpoint["wan_state"].keys():
109-
opt_state = restored_checkpoint["wan_state"]["opt_state"]
110+
# Check for optimizer state in either transformer
111+
if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys():
112+
opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"]
113+
elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys():
114+
opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"]
110115
else:
111116
max_logging.log("No checkpoint found, loading default pipeline.")
112117
pipeline = self.load_diffusers_checkpoint()
@@ -124,18 +129,12 @@ def config_to_json(model_or_config):
124129
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
125130
}
126131

127-
if "low_noise_transformer" in train_states:
128-
low_noise_state = train_states["low_noise_transformer"]
129-
items["low_noise_transformer_state"] = ocp.args.PyTreeSave(low_noise_state)
130-
131-
if "high_noise_transformer" in train_states:
132-
high_noise_state = train_states["high_noise_transformer"]
133-
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(high_noise_state)
132+
items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
133+
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])
134134

135135
# Save the checkpoint
136-
if len(items) > 1:
137-
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
138-
max_logging.log(f"Checkpoint for step {train_step} saved.")
136+
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
137+
max_logging.log(f"Checkpoint for step {train_step} saved.")
139138

140139

141140
def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
https://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
"""
13+
14+
import unittest
15+
from unittest.mock import patch, MagicMock
16+
17+
from maxdiffusion.checkpointing.wan_checkpointer2_2 import WanCheckpointer, WAN_CHECKPOINT
18+
19+
20+
class WanCheckpointerTest(unittest.TestCase):
21+
22+
def setUp(self):
23+
self.config = MagicMock()
24+
self.config.checkpoint_dir = "/tmp/wan_checkpoint_test"
25+
self.config.dataset_type = "test_dataset"
26+
27+
@patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager")
28+
@patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline")
29+
def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager):
30+
mock_manager = MagicMock()
31+
mock_manager.latest_step.return_value = None
32+
mock_create_manager.return_value = mock_manager
33+
34+
mock_pipeline_instance = MagicMock()
35+
mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance
36+
37+
checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
38+
pipeline, opt_state, step = checkpointer.load_checkpoint(step=None)
39+
40+
mock_manager.latest_step.assert_called_once()
41+
mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config)
42+
self.assertEqual(pipeline, mock_pipeline_instance)
43+
self.assertIsNone(opt_state)
44+
self.assertIsNone(step)
45+
46+
@patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager")
47+
@patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline")
48+
def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager):
49+
mock_manager = MagicMock()
50+
mock_manager.latest_step.return_value = 1
51+
metadata_mock = MagicMock()
52+
metadata_mock.low_noise_transformer_state = {}
53+
metadata_mock.high_noise_transformer_state = {}
54+
mock_manager.item_metadata.return_value = metadata_mock
55+
56+
restored_mock = MagicMock()
57+
restored_mock.low_noise_transformer_state = {"params": {}}
58+
restored_mock.high_noise_transformer_state = {"params": {}}
59+
restored_mock.wan_config = {}
60+
restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"]
61+
62+
mock_manager.restore.return_value = restored_mock
63+
64+
mock_create_manager.return_value = mock_manager
65+
66+
mock_pipeline_instance = MagicMock()
67+
mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
68+
69+
checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
70+
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
71+
72+
mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
73+
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
74+
self.assertEqual(pipeline, mock_pipeline_instance)
75+
self.assertIsNone(opt_state)
76+
self.assertEqual(step, 1)
77+
78+
@patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager")
79+
@patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline")
80+
def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager):
81+
mock_manager = MagicMock()
82+
mock_manager.latest_step.return_value = 1
83+
metadata_mock = MagicMock()
84+
metadata_mock.low_noise_transformer_state = {}
85+
metadata_mock.high_noise_transformer_state = {}
86+
mock_manager.item_metadata.return_value = metadata_mock
87+
88+
restored_mock = MagicMock()
89+
restored_mock.low_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.001}}
90+
restored_mock.high_noise_transformer_state = {"params": {}}
91+
restored_mock.wan_config = {}
92+
restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"]
93+
94+
mock_manager.restore.return_value = restored_mock
95+
96+
mock_create_manager.return_value = mock_manager
97+
98+
mock_pipeline_instance = MagicMock()
99+
mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
100+
101+
checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
102+
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
103+
104+
mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
105+
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
106+
self.assertEqual(pipeline, mock_pipeline_instance)
107+
self.assertIsNotNone(opt_state)
108+
self.assertEqual(opt_state["learning_rate"], 0.001)
109+
self.assertEqual(step, 1)
110+
111+
112+
if __name__ == "__main__":
113+
unittest.main()

0 commit comments

Comments
 (0)