1+ """
2+ Copyright 2025 Google LLC
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+ https://www.apache.org/licenses/LICENSE-2.0
7+ Unless required by applicable law or agreed to in writing, software
8+ distributed under the License is distributed on an "AS IS" BASIS,
9+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+ See the License for the specific language governing permissions and
11+ limitations under the License.
12+ """
13+
14+ import unittest
15+ from unittest .mock import patch , MagicMock
16+
17+ from maxdiffusion .checkpointing .wan_checkpointer import WanCheckpointer , WAN_CHECKPOINT
18+
19+ class WanCheckpointerTest (unittest .TestCase ):
20+ def setUp (self ):
21+ self .config = MagicMock ()
22+ self .config .checkpoint_dir = "/tmp/wan_checkpoint_test"
23+ self .config .dataset_type = "test_dataset"
24+
25+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager' )
26+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline' )
27+ def test_load_from_diffusers (self , mock_wan_pipeline , mock_create_manager ):
28+ mock_manager = MagicMock ()
29+ mock_manager .latest_step .return_value = None
30+ mock_create_manager .return_value = mock_manager
31+
32+ mock_pipeline_instance = MagicMock ()
33+ mock_wan_pipeline .from_pretrained .return_value = mock_pipeline_instance
34+
35+ checkpointer = WanCheckpointer (self .config , WAN_CHECKPOINT )
36+ pipeline , opt_state , step = checkpointer .load_checkpoint (step = None )
37+
38+ mock_manager .latest_step .assert_called_once ()
39+ mock_wan_pipeline .from_pretrained .assert_called_once_with (self .config )
40+ self .assertEqual (pipeline , mock_pipeline_instance )
41+ self .assertIsNone (opt_state )
42+ self .assertIsNone (step )
43+
44+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager' )
45+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline' )
46+ def test_load_checkpoint_no_optimizer (self , mock_wan_pipeline , mock_create_manager ):
47+ mock_manager = MagicMock ()
48+ mock_manager .latest_step .return_value = 1
49+ metadata_mock = MagicMock ()
50+ metadata_mock .wan_state = {}
51+ mock_manager .item_metadata .return_value = metadata_mock
52+
53+ restored_mock = MagicMock ()
54+ restored_mock .wan_state = {'params' : {}}
55+ restored_mock .wan_config = {}
56+ restored_mock .keys .return_value = ['wan_state' , 'wan_config' ]
57+ mock_manager .restore .return_value = restored_mock
58+
59+ mock_create_manager .return_value = mock_manager
60+
61+ mock_pipeline_instance = MagicMock ()
62+ mock_wan_pipeline .from_checkpoint .return_value = mock_pipeline_instance
63+
64+ checkpointer = WanCheckpointer (self .config , WAN_CHECKPOINT )
65+ pipeline , opt_state , step = checkpointer .load_checkpoint (step = 1 )
66+
67+ mock_manager .restore .assert_called_once_with (
68+ directory = unittest .mock .ANY ,
69+ step = 1 ,
70+ args = unittest .mock .ANY
71+ )
72+ mock_wan_pipeline .from_checkpoint .assert_called_with (self .config , mock_manager .restore .return_value )
73+ self .assertEqual (pipeline , mock_pipeline_instance )
74+ self .assertIsNone (opt_state )
75+ self .assertEqual (step , 1 )
76+
77+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager' )
78+ @patch ('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline' )
79+ def test_load_checkpoint_with_optimizer (self , mock_wan_pipeline , mock_create_manager ):
80+ mock_manager = MagicMock ()
81+ mock_manager .latest_step .return_value = 1
82+ metadata_mock = MagicMock ()
83+ metadata_mock .wan_state = {}
84+ mock_manager .item_metadata .return_value = metadata_mock
85+
86+ restored_mock = MagicMock ()
87+ restored_mock .wan_state = {'params' : {}, 'opt_state' : {'learning_rate' : 0.001 }}
88+ restored_mock .wan_config = {}
89+ restored_mock .keys .return_value = ['wan_state' , 'wan_config' ]
90+ mock_manager .restore .return_value = restored_mock
91+
92+ mock_create_manager .return_value = mock_manager
93+
94+ mock_pipeline_instance = MagicMock ()
95+ mock_wan_pipeline .from_checkpoint .return_value = mock_pipeline_instance
96+
97+ checkpointer = WanCheckpointer (self .config , WAN_CHECKPOINT )
98+ pipeline , opt_state , step = checkpointer .load_checkpoint (step = 1 )
99+
100+ mock_manager .restore .assert_called_once_with (
101+ directory = unittest .mock .ANY ,
102+ step = 1 ,
103+ args = unittest .mock .ANY
104+ )
105+ mock_wan_pipeline .from_checkpoint .assert_called_with (self .config , mock_manager .restore .return_value )
106+ self .assertEqual (pipeline , mock_pipeline_instance )
107+ self .assertIsNotNone (opt_state )
108+ self .assertEqual (opt_state ['learning_rate' ], 0.001 )
109+ self .assertEqual (step , 1 )
110+
111+ if __name__ == "__main__" :
112+ unittest .main ()
0 commit comments