1+ """
2+ Copyright 2025 Google LLC
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+ https://www.apache.org/licenses/LICENSE-2.0
7+ Unless required by applicable law or agreed to in writing, software
8+ distributed under the License is distributed on an "AS IS" BASIS,
9+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+ See the License for the specific language governing permissions and
11+ limitations under the License.
12+ """
13+
14+ import unittest
15+ from unittest .mock import patch , MagicMock
16+
17+ from maxdiffusion .checkpointing .wan_checkpointer import WanCheckpointer , WAN_CHECKPOINT
18+
19+ class WanCheckpointerTest (unittest .TestCase ):
20+ def setUp (self ):
21+ self .config = MagicMock ()
22+ self .config .checkpoint_dir = "/tmp/wan_checkpoint_test"
23+ self .config .dataset_type = "test_dataset"
24+
25+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager' )
26+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline' )
27+ def test_load_from_diffusers (self , mock_wan_pipeline , mock_create_manager ):
28+ mock_manager = MagicMock ()
29+ mock_manager .latest_step .return_value = None
30+ mock_create_manager .return_value = mock_manager
31+
32+ mock_pipeline_instance = MagicMock ()
33+ mock_wan_pipeline .from_pretrained .return_value = mock_pipeline_instance
34+
35+ checkpointer = WanCheckpointer (self .config , WAN_CHECKPOINT )
36+ pipeline , opt_state , step = checkpointer .load_checkpoint (step = None )
37+
38+ mock_manager .latest_step .assert_called_once ()
39+ mock_wan_pipeline .from_pretrained .assert_called_once_with (self .config )
40+ self .assertEqual (pipeline , mock_pipeline_instance )
41+ self .assertIsNone (opt_state )
42+ self .assertIsNone (step )
43+
44+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager' )
45+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline' )
46+ def test_load_checkpoint_no_optimizer (self , mock_wan_pipeline , mock_create_manager ):
47+ mock_manager = MagicMock ()
48+ mock_manager .latest_step .return_value = 1
49+ metadata_mock = MagicMock ()
50+ metadata_mock .wan_state = {}
51+ mock_manager .item_metadata .return_value = metadata_mock
52+ mock_manager .restore .return_value = {
53+ 'wan_state' : {'params' : {}},
54+ 'wan_config' : {}
55+ }
56+ mock_create_manager .return_value = mock_manager
57+
58+ mock_pipeline_instance = MagicMock ()
59+ mock_wan_pipeline .from_checkpoint .return_value = mock_pipeline_instance
60+
61+ checkpointer = WanCheckpointer (self .config , WAN_CHECKPOINT )
62+ pipeline , opt_state , step = checkpointer .load_checkpoint (step = 1 )
63+
64+ mock_manager .restore .assert_called_once_with (
65+ directory = unittest .mock .ANY ,
66+ step = 1 ,
67+ args = unittest .mock .ANY
68+ )
69+ mock_wan_pipeline .from_checkpoint .assert_called_with (self .config , mock_manager .restore .return_value )
70+ self .assertEqual (pipeline , mock_pipeline_instance )
71+ self .assertIsNone (opt_state )
72+ self .assertEqual (step , 1 )
73+
74+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager' )
75+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline' )
76+ def test_load_checkpoint_with_optimizer (self , mock_wan_pipeline , mock_create_manager ):
77+ mock_manager = MagicMock ()
78+ mock_manager .latest_step .return_value = 1
79+ metadata_mock = MagicMock ()
80+ metadata_mock .wan_state = {}
81+ mock_manager .item_metadata .return_value = metadata_mock
82+ mock_manager .restore .return_value = {
83+ 'wan_state' : {'params' : {}, 'opt_state' : {'learning_rate' : 0.001 }},
84+ 'wan_config' : {}
85+ }
86+ mock_create_manager .return_value = mock_manager
87+
88+ mock_pipeline_instance = MagicMock ()
89+ mock_wan_pipeline .from_checkpoint .return_value = mock_pipeline_instance
90+
91+ checkpointer = WanCheckpointer (self .config , WAN_CHECKPOINT )
92+ pipeline , opt_state , step = checkpointer .load_checkpoint (step = 1 )
93+
94+ mock_manager .restore .assert_called_once_with (
95+ directory = unittest .mock .ANY ,
96+ step = 1 ,
97+ args = unittest .mock .ANY
98+ )
99+ mock_wan_pipeline .from_checkpoint .assert_called_with (self .config , mock_manager .restore .return_value )
100+ self .assertEqual (pipeline , mock_pipeline_instance )
101+ self .assertIsNotNone (opt_state )
102+ self .assertEqual (opt_state ['learning_rate' ], 0.001 )
103+ self .assertEqual (step , 1 )
104+
105+ if __name__ == "__main__" :
106+ unittest .main ()
0 commit comments