Skip to content

Commit 35c68ca

Browse files
committed
removed extra tests
1 parent 9e46e73 commit 35c68ca

1 file changed

Lines changed: 0 additions & 21 deletions

File tree

src/maxdiffusion/tests/wan_checkpointer_test.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -239,27 +239,6 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, m
239239
self.assertEqual(opt_state["learning_rate"], 0.002)
240240
self.assertEqual(step, 1)
241241

242-
@patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
243-
def test_save_checkpoint(self, mock_create_manager):
244-
"""Test saving checkpoint for WAN 2.2."""
245-
mock_manager = MagicMock()
246-
mock_create_manager.return_value = mock_manager
247-
248-
mock_pipeline = MagicMock()
249-
mock_pipeline.low_noise_transformer.to_json_string.return_value = '{"config": "test"}'
250-
251-
train_states = {
252-
"low_noise_transformer": {"params": {}, "opt_state": {}},
253-
"high_noise_transformer": {"params": {}}
254-
}
255-
256-
checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config)
257-
checkpointer.save_checkpoint(train_step=100, pipeline=mock_pipeline, train_states=train_states)
258-
259-
mock_manager.save.assert_called_once()
260-
call_args = mock_manager.save.call_args
261-
self.assertEqual(call_args[0][0], 100) # train_step
262-
263242

264243
class WanCheckpointerFactoryTest(unittest.TestCase):
265244
"""Tests for checkpointer factory/selection logic."""

0 commit comments

Comments
 (0)