Skip to content

Commit ae7937f

Browse files
committed
wan_checkpointer_test.py corrected
1 parent 60f96a7 commit ae7937f

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxdiffusion/tests/wan_checkpointer_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,12 @@ def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline_i2v
408408
checkpointer = WanCheckpointerI2V_2_2(config=self.config)
409409
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
410410

411+
mock_manager.restore.assert_called_once()
412+
mock_wan_pipeline_i2v_2p2.from_checkpoint.asset_called_once_with(self.config, restored_mock)
413+
self.assertEqual(pipeline, mock_pipeline_instance)
411414
self.assertIsNotNone(opt_state)
412415
self.assertEqual(opt_state["learning_rate"], 0.001)
416+
self.assertEqual(step,1)
413417

414418
@patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
415419
@patch("maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2.WanPipelineI2V_2_2")
@@ -436,8 +440,12 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline_i2
436440
checkpointer = WanCheckpointerI2V_2_2(config=self.config)
437441
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
438442

443+
mock_manager.restore.assert_called_once()
444+
mock_wan_pipeline_i2v_2p2.from_checkpoint.assert_called_once_with(self.config, restored_mock)
445+
self.assertEqual(pipeline, mock_pipeline_instance)
439446
self.assertIsNotNone(opt_state)
440447
self.assertEqual(opt_state["learning_rate"], 0.002)
448+
self.assertEqual(step, 1)
441449

442450
class WanCheckpointerEdgeCasesTest(unittest.TestCase):
443451
"""Tests for edge cases and error handling."""

0 commit comments

Comments
 (0)