@@ -408,8 +408,12 @@ def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline_i2v
408408 checkpointer = WanCheckpointerI2V_2_2 (config = self .config )
409409 pipeline , opt_state , step = checkpointer .load_checkpoint (step = 1 )
410410
411+ mock_manager .restore .assert_called_once ()
412+ mock_wan_pipeline_i2v_2p2 .from_checkpoint .asset_called_once_with (self .config , restored_mock )
413+ self .assertEqual (pipeline , mock_pipeline_instance )
411414 self .assertIsNotNone (opt_state )
412415 self .assertEqual (opt_state ["learning_rate" ], 0.001 )
416+ self .assertEqual (step ,1 )
413417
414418 @patch ("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager" )
415419 @patch ("maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2.WanPipelineI2V_2_2" )
@@ -436,8 +440,12 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline_i2
436440 checkpointer = WanCheckpointerI2V_2_2 (config = self .config )
437441 pipeline , opt_state , step = checkpointer .load_checkpoint (step = 1 )
438442
443+ mock_manager .restore .assert_called_once ()
444+ mock_wan_pipeline_i2v_2p2 .from_checkpoint .assert_called_once_with (self .config , restored_mock )
445+ self .assertEqual (pipeline , mock_pipeline_instance )
439446 self .assertIsNotNone (opt_state )
440447 self .assertEqual (opt_state ["learning_rate" ], 0.002 )
448+ self .assertEqual (step , 1 )
441449
442450class WanCheckpointerEdgeCasesTest (unittest .TestCase ):
443451 """Tests for edge cases and error handling."""
0 commit comments