Skip to content

Commit 72a0b90

Browse files
committed
Bug fix, case where not checkpoint load from diffusers
1 parent feebbb1 commit 72a0b90

2 files changed

Lines changed: 107 additions & 3 deletions

File tree

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import jax
2121
import numpy as np
22+
from typing import Optional, Tuple
2223
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
2324
from ..pipelines.wan.wan_pipeline import WanPipeline
2425
from .. import max_logging, max_utils
@@ -50,12 +51,13 @@ def _create_optimizer(self, model, config, learning_rate):
5051
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
5152
return tx, learning_rate_scheduler
5253

53-
def load_wan_configs_from_orbax(self, step):
54+
def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
5455
if step is None:
5556
step = self.checkpoint_manager.latest_step()
5657
max_logging.log(f"Latest WAN checkpoint step: {step}")
5758
if step is None:
58-
return None
59+
max_logging.log("No WAN checkpoint found.")
60+
return None, None
5961
max_logging.log(f"Loading WAN checkpoint from step {step}")
6062
metadatas = self.checkpoint_manager.item_metadata(step)
6163
transformer_metadata = metadatas.wan_state
@@ -86,7 +88,7 @@ def load_diffusers_checkpoint(self):
8688
pipeline = WanPipeline.from_pretrained(self.config)
8789
return pipeline
8890

89-
def load_checkpoint(self, step=None):
91+
def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]:
9092
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
9193
opt_state = None
9294
if restored_checkpoint:
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
https://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
"""
13+
14+
import unittest
15+
from unittest.mock import patch, MagicMock
16+
17+
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT
18+
19+
class WanCheckpointerTest(unittest.TestCase):
20+
def setUp(self):
21+
self.config = MagicMock()
22+
self.config.checkpoint_dir = "/tmp/wan_checkpoint_test"
23+
self.config.dataset_type = "test_dataset"
24+
25+
@patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
26+
@patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
27+
def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager):
28+
mock_manager = MagicMock()
29+
mock_manager.latest_step.return_value = None
30+
mock_create_manager.return_value = mock_manager
31+
32+
mock_pipeline_instance = MagicMock()
33+
mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance
34+
35+
checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
36+
pipeline, opt_state, step = checkpointer.load_checkpoint(step=None)
37+
38+
mock_manager.latest_step.assert_called_once()
39+
mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config)
40+
self.assertEqual(pipeline, mock_pipeline_instance)
41+
self.assertIsNone(opt_state)
42+
self.assertIsNone(step)
43+
44+
@patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
45+
@patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
46+
def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager):
47+
mock_manager = MagicMock()
48+
mock_manager.latest_step.return_value = 1
49+
mock_manager.item_metadata.return_value = MagicMock()
50+
mock_manager.restore.return_value = {
51+
'wan_state': {'params': {}},
52+
'wan_config': {}
53+
}
54+
mock_create_manager.return_value = mock_manager
55+
56+
mock_pipeline_instance = MagicMock()
57+
mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
58+
59+
checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
60+
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
61+
62+
mock_manager.restore.assert_called_once_with(
63+
directory=unittest.mock.ANY,
64+
step=1,
65+
args=unittest.mock.ANY
66+
)
67+
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
68+
self.assertEqual(pipeline, mock_pipeline_instance)
69+
self.assertIsNone(opt_state)
70+
self.assertEqual(step, 1)
71+
72+
@patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
73+
@patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
74+
def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager):
75+
mock_manager = MagicMock()
76+
mock_manager.latest_step.return_value = 1
77+
mock_manager.item_metadata.return_value = MagicMock()
78+
mock_manager.restore.return_value = {
79+
'wan_state': {'params': {}, 'opt_state': {'learning_rate': 0.001}},
80+
'wan_config': {}
81+
}
82+
mock_create_manager.return_value = mock_manager
83+
84+
mock_pipeline_instance = MagicMock()
85+
mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
86+
87+
checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
88+
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
89+
90+
mock_manager.restore.assert_called_once_with(
91+
directory=unittest.mock.ANY,
92+
step=1,
93+
args=unittest.mock.ANY
94+
)
95+
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
96+
self.assertEqual(pipeline, mock_pipeline_instance)
97+
self.assertIsNotNone(opt_state)
98+
self.assertEqual(opt_state['learning_rate'], 0.001)
99+
self.assertEqual(step, 1)
100+
101+
if __name__ == "__main__":
102+
unittest.main()

0 commit comments

Comments
 (0)