1+ """
2+ Copyright 2025 Google LLC
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+ https://www.apache.org/licenses/LICENSE-2.0
7+ Unless required by applicable law or agreed to in writing, software
8+ distributed under the License is distributed on an "AS IS" BASIS,
9+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+ See the License for the specific language governing permissions and
11+ limitations under the License.
12+ """
13+
14+ import unittest
15+ from unittest .mock import patch , MagicMock
16+
17+ from maxdiffusion .checkpointing .wan_checkpointer2_2 import WanCheckpointer , WAN_CHECKPOINT
18+
19+
20+ class WanCheckpointerTest (unittest .TestCase ):
21+
22+ def setUp (self ):
23+ self .config = MagicMock ()
24+ self .config .checkpoint_dir = "/tmp/wan_checkpoint_test"
25+ self .config .dataset_type = "test_dataset"
26+
27+ @patch ("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager" )
28+ @patch ("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline" )
29+ def test_load_from_diffusers (self , mock_wan_pipeline , mock_create_manager ):
30+ mock_manager = MagicMock ()
31+ mock_manager .latest_step .return_value = None
32+ mock_create_manager .return_value = mock_manager
33+
34+ mock_pipeline_instance = MagicMock ()
35+ mock_wan_pipeline .from_pretrained .return_value = mock_pipeline_instance
36+
37+ checkpointer = WanCheckpointer (self .config , WAN_CHECKPOINT )
38+ pipeline , opt_state , step = checkpointer .load_checkpoint (step = None )
39+
40+ mock_manager .latest_step .assert_called_once ()
41+ mock_wan_pipeline .from_pretrained .assert_called_once_with (self .config )
42+ self .assertEqual (pipeline , mock_pipeline_instance )
43+ self .assertIsNone (opt_state )
44+ self .assertIsNone (step )
45+
46+ @patch ("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager" )
47+ @patch ("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline" )
48+ def test_load_checkpoint_no_optimizer (self , mock_wan_pipeline , mock_create_manager ):
49+ mock_manager = MagicMock ()
50+ mock_manager .latest_step .return_value = 1
51+ metadata_mock = MagicMock ()
52+ metadata_mock .low_noise_transformer_state = {}
53+ metadata_mock .high_noise_transformer_state = {}
54+ mock_manager .item_metadata .return_value = metadata_mock
55+
56+ restored_mock = MagicMock ()
57+ restored_mock .low_noise_transformer_state = {"params" : {}}
58+ restored_mock .high_noise_transformer_state = {"params" : {}}
59+ restored_mock .wan_config = {}
60+ restored_mock .keys .return_value = ["low_noise_transformer_state" , "high_noise_transformer_state" , "wan_config" ]
61+
62+ mock_manager .restore .return_value = restored_mock
63+
64+ mock_create_manager .return_value = mock_manager
65+
66+ mock_pipeline_instance = MagicMock ()
67+ mock_wan_pipeline .from_checkpoint .return_value = mock_pipeline_instance
68+
69+ checkpointer = WanCheckpointer (self .config , WAN_CHECKPOINT )
70+ pipeline , opt_state , step = checkpointer .load_checkpoint (step = 1 )
71+
72+ mock_manager .restore .assert_called_once_with (directory = unittest .mock .ANY , step = 1 , args = unittest .mock .ANY )
73+ mock_wan_pipeline .from_checkpoint .assert_called_with (self .config , mock_manager .restore .return_value )
74+ self .assertEqual (pipeline , mock_pipeline_instance )
75+ self .assertIsNone (opt_state )
76+ self .assertEqual (step , 1 )
77+
78+ @patch ("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager" )
79+ @patch ("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline" )
80+ def test_load_checkpoint_with_optimizer (self , mock_wan_pipeline , mock_create_manager ):
81+ mock_manager = MagicMock ()
82+ mock_manager .latest_step .return_value = 1
83+ metadata_mock = MagicMock ()
84+ metadata_mock .low_noise_transformer_state = {}
85+ metadata_mock .high_noise_transformer_state = {}
86+ mock_manager .item_metadata .return_value = metadata_mock
87+
88+ restored_mock = MagicMock ()
89+ restored_mock .low_noise_transformer_state = {"params" : {}, "opt_state" : {"learning_rate" : 0.001 }}
90+ restored_mock .high_noise_transformer_state = {"params" : {}}
91+ restored_mock .wan_config = {}
92+ restored_mock .keys .return_value = ["low_noise_transformer_state" , "high_noise_transformer_state" , "wan_config" ]
93+
94+ mock_manager .restore .return_value = restored_mock
95+
96+ mock_create_manager .return_value = mock_manager
97+
98+ mock_pipeline_instance = MagicMock ()
99+ mock_wan_pipeline .from_checkpoint .return_value = mock_pipeline_instance
100+
101+ checkpointer = WanCheckpointer (self .config , WAN_CHECKPOINT )
102+ pipeline , opt_state , step = checkpointer .load_checkpoint (step = 1 )
103+
104+ mock_manager .restore .assert_called_once_with (directory = unittest .mock .ANY , step = 1 , args = unittest .mock .ANY )
105+ mock_wan_pipeline .from_checkpoint .assert_called_with (self .config , mock_manager .restore .return_value )
106+ self .assertEqual (pipeline , mock_pipeline_instance )
107+ self .assertIsNotNone (opt_state )
108+ self .assertEqual (opt_state ["learning_rate" ], 0.001 )
109+ self .assertEqual (step , 1 )
110+
111+
112+ if __name__ == "__main__" :
113+ unittest .main ()
0 commit comments