1+ """
2+ Copyright 2025 Google LLC
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+ https://www.apache.org/licenses/LICENSE-2.0
7+ Unless required by applicable law or agreed to in writing, software
8+ distributed under the License is distributed on an "AS IS" BASIS,
9+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+ See the License for the specific language governing permissions and
11+ limitations under the License.
12+ """
13+
14+ import unittest
15+ from unittest .mock import patch , MagicMock
16+
17+ from maxdiffusion .checkpointing .wan_checkpointer import WanCheckpointer , WAN_CHECKPOINT
18+
19+ class WanCheckpointerTest (unittest .TestCase ):
20+ def setUp (self ):
21+ self .config = MagicMock ()
22+ self .config .checkpoint_dir = "/tmp/wan_checkpoint_test"
23+ self .config .dataset_type = "test_dataset"
24+
25+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager' )
26+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline' )
27+ def test_load_from_diffusers (self , mock_wan_pipeline , mock_create_manager ):
28+ mock_manager = MagicMock ()
29+ mock_manager .latest_step .return_value = None
30+ mock_create_manager .return_value = mock_manager
31+
32+ mock_pipeline_instance = MagicMock ()
33+ mock_wan_pipeline .from_pretrained .return_value = mock_pipeline_instance
34+
35+ checkpointer = WanCheckpointer (self .config , WAN_CHECKPOINT )
36+ pipeline , opt_state , step = checkpointer .load_checkpoint (step = None )
37+
38+ mock_manager .latest_step .assert_called_once ()
39+ mock_wan_pipeline .from_pretrained .assert_called_once_with (self .config )
40+ self .assertEqual (pipeline , mock_pipeline_instance )
41+ self .assertIsNone (opt_state )
42+ self .assertIsNone (step )
43+
44+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager' )
45+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline' )
46+ def test_load_checkpoint_no_optimizer (self , mock_wan_pipeline , mock_create_manager ):
47+ mock_manager = MagicMock ()
48+ mock_manager .latest_step .return_value = 1
49+ mock_manager .item_metadata .return_value = MagicMock ()
50+ mock_manager .restore .return_value = {
51+ 'wan_state' : {'params' : {}},
52+ 'wan_config' : {}
53+ }
54+ mock_create_manager .return_value = mock_manager
55+
56+ mock_pipeline_instance = MagicMock ()
57+ mock_wan_pipeline .from_checkpoint .return_value = mock_pipeline_instance
58+
59+ checkpointer = WanCheckpointer (self .config , WAN_CHECKPOINT )
60+ pipeline , opt_state , step = checkpointer .load_checkpoint (step = 1 )
61+
62+ mock_manager .restore .assert_called_once_with (
63+ directory = unittest .mock .ANY ,
64+ step = 1 ,
65+ args = unittest .mock .ANY
66+ )
67+ mock_wan_pipeline .from_checkpoint .assert_called_with (self .config , mock_manager .restore .return_value )
68+ self .assertEqual (pipeline , mock_pipeline_instance )
69+ self .assertIsNone (opt_state )
70+ self .assertEqual (step , 1 )
71+
72+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager' )
73+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline' )
74+ def test_load_checkpoint_with_optimizer (self , mock_wan_pipeline , mock_create_manager ):
75+ mock_manager = MagicMock ()
76+ mock_manager .latest_step .return_value = 1
77+ mock_manager .item_metadata .return_value = MagicMock ()
78+ mock_manager .restore .return_value = {
79+ 'wan_state' : {'params' : {}, 'opt_state' : {'learning_rate' : 0.001 }},
80+ 'wan_config' : {}
81+ }
82+ mock_create_manager .return_value = mock_manager
83+
84+ mock_pipeline_instance = MagicMock ()
85+ mock_wan_pipeline .from_checkpoint .return_value = mock_pipeline_instance
86+
87+ checkpointer = WanCheckpointer (self .config , WAN_CHECKPOINT )
88+ pipeline , opt_state , step = checkpointer .load_checkpoint (step = 1 )
89+
90+ mock_manager .restore .assert_called_once_with (
91+ directory = unittest .mock .ANY ,
92+ step = 1 ,
93+ args = unittest .mock .ANY
94+ )
95+ mock_wan_pipeline .from_checkpoint .assert_called_with (self .config , mock_manager .restore .return_value )
96+ self .assertEqual (pipeline , mock_pipeline_instance )
97+ self .assertIsNotNone (opt_state )
98+ self .assertEqual (opt_state ['learning_rate' ], 0.001 )
99+ self .assertEqual (step , 1 )
100+
101+ if __name__ == "__main__" :
102+ unittest .main ()
0 commit comments