Skip to content

Commit 9b623e4

Browse files
committed
Bug fix: handle the case where no checkpoint is found and the pipeline is loaded from diffusers
1 parent feebbb1 commit 9b623e4

3 files changed

Lines changed: 112 additions & 4 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ jobs:
5858
pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
5959
- name: PyTest
6060
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
61-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
61+
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest src/maxdiffusion/tests/wan_checkpointer_test.py
6262
# add_pull_ready:
6363
# if: github.ref != 'refs/heads/main'
6464
# permissions:

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import jax
2121
import numpy as np
22+
from typing import Optional, Tuple
2223
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
2324
from ..pipelines.wan.wan_pipeline import WanPipeline
2425
from .. import max_logging, max_utils
@@ -50,12 +51,13 @@ def _create_optimizer(self, model, config, learning_rate):
5051
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
5152
return tx, learning_rate_scheduler
5253

53-
def load_wan_configs_from_orbax(self, step):
54+
def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
5455
if step is None:
5556
step = self.checkpoint_manager.latest_step()
5657
max_logging.log(f"Latest WAN checkpoint step: {step}")
5758
if step is None:
58-
return None
59+
max_logging.log("No WAN checkpoint found.")
60+
return None, None
5961
max_logging.log(f"Loading WAN checkpoint from step {step}")
6062
metadatas = self.checkpoint_manager.item_metadata(step)
6163
transformer_metadata = metadatas.wan_state
@@ -86,7 +88,7 @@ def load_diffusers_checkpoint(self):
8688
pipeline = WanPipeline.from_pretrained(self.config)
8789
return pipeline
8890

89-
def load_checkpoint(self, step=None):
91+
def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]:
9092
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
9193
opt_state = None
9294
if restored_checkpoint:
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
https://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
"""
13+
14+
import unittest
15+
from unittest.mock import patch, MagicMock
16+
17+
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT
18+
19+
class WanCheckpointerTest(unittest.TestCase):
    """Tests for ``WanCheckpointer.load_checkpoint``.

    ``create_orbax_checkpoint_manager`` and ``WanPipeline`` are patched in the
    ``wan_checkpointer`` module, so every test runs against mocks only.
    """

    def setUp(self):
        """Create the mock config that each test passes to the checkpointer."""
        cfg = MagicMock()
        cfg.checkpoint_dir = "/tmp/wan_checkpoint_test"
        cfg.dataset_type = "test_dataset"
        self.config = cfg

    @staticmethod
    def _manager_with_checkpoint(wan_state):
        """Return a mock checkpoint manager whose ``restore`` yields ``wan_state``.

        ``latest_step`` reports step 1 and ``item_metadata`` returns an object
        with an empty ``wan_state`` mapping.
        """
        manager = MagicMock()
        manager.latest_step.return_value = 1
        metadata = MagicMock()
        metadata.wan_state = {}
        manager.item_metadata.return_value = metadata
        manager.restore.return_value = {
            "wan_state": wan_state,
            "wan_config": {},
        }
        return manager

    def _load_step_one(self, mock_wan_pipeline, mock_create_manager, wan_state):
        """Load step 1 through the mocked manager and check the restore path.

        Verifies that ``restore`` was called once for step 1, that
        ``WanPipeline.from_checkpoint`` received the config together with the
        restored mapping, and that its result is the returned pipeline.

        Returns:
            The ``(opt_state, step)`` pair from ``load_checkpoint`` so the
            caller can assert on the values specific to its scenario.
        """
        manager = self._manager_with_checkpoint(wan_state)
        mock_create_manager.return_value = manager

        expected_pipeline = MagicMock()
        mock_wan_pipeline.from_checkpoint.return_value = expected_pipeline

        checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
        pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)

        manager.restore.assert_called_once_with(
            directory=unittest.mock.ANY,
            step=1,
            args=unittest.mock.ANY,
        )
        mock_wan_pipeline.from_checkpoint.assert_called_with(
            self.config, manager.restore.return_value
        )
        self.assertEqual(pipeline, expected_pipeline)
        return opt_state, step

    @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
    @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline")
    def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager):
        """With no saved step, the pipeline comes from ``from_pretrained``."""
        manager = MagicMock()
        manager.latest_step.return_value = None  # no checkpoint available
        mock_create_manager.return_value = manager

        expected_pipeline = MagicMock()
        mock_wan_pipeline.from_pretrained.return_value = expected_pipeline

        checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
        pipeline, opt_state, step = checkpointer.load_checkpoint(step=None)

        manager.latest_step.assert_called_once()
        mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config)
        self.assertEqual(pipeline, expected_pipeline)
        self.assertIsNone(opt_state)
        self.assertIsNone(step)

    @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
    @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline")
    def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager):
        """A restored state without ``opt_state`` yields ``None`` for it."""
        opt_state, step = self._load_step_one(
            mock_wan_pipeline, mock_create_manager, {"params": {}}
        )

        self.assertIsNone(opt_state)
        self.assertEqual(step, 1)

    @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
    @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline")
    def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager):
        """A restored state that carries ``opt_state`` hands it back to the caller."""
        opt_state, step = self._load_step_one(
            mock_wan_pipeline,
            mock_create_manager,
            {"params": {}, "opt_state": {"learning_rate": 0.001}},
        )

        self.assertIsNotNone(opt_state)
        self.assertEqual(opt_state["learning_rate"], 0.001)
        self.assertEqual(step, 1)
105+
if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main()

0 commit comments

Comments
 (0)