Skip to content

Commit 60f96a7

Browse files
committed
wan_checkpointer_test.py corrected
1 parent 62c208f commit 60f96a7

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/tests/wan_checkpointer_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,13 +312,13 @@ def test_load_checkpoint_with_optimizer(self, mock_from_checkpoint, mock_create_
312312
mock_create_manager.return_value = mock_manager
313313

314314
mock_pipeline_instance = MagicMock()
315-
mock_wan_pipeline_i2v_2p1.from_checkpoint.return_value = mock_pipeline_instance
315+
mock_from_checkpoint.return_value = mock_pipeline_instance
316316

317317
checkpointer = WanCheckpointerI2V_2_1(config=self.config)
318318
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
319319

320320
mock_manager.restore.assert_called_once()
321-
mock_wan_pipeline_i2v_2p1.from_checkpoint.assert_called_once_with(self.config, restored_mock)
321+
mock_from_checkpoint.assert_called_once_with(self.config, restored_mock)
322322
self.assertEqual(pipeline, mock_pipeline_instance)
323323
self.assertIsNotNone(opt_state)
324324
self.assertEqual(opt_state["learning_rate"], 0.001)

0 commit comments

Comments
 (0)