diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 2c588b439..1512485b6 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -58,7 +58,7 @@ jobs: pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets - name: PyTest run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index a1d029a43..0dd493a33 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -19,6 +19,7 @@ import jax import numpy as np +from typing import Optional, Tuple from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) from ..pipelines.wan.wan_pipeline import WanPipeline from .. 
import max_logging, max_utils @@ -50,12 +51,13 @@ def _create_optimizer(self, model, config, learning_rate): tx = max_utils.create_optimizer(config, learning_rate_scheduler) return tx, learning_rate_scheduler - def load_wan_configs_from_orbax(self, step): + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: if step is None: step = self.checkpoint_manager.latest_step() max_logging.log(f"Latest WAN checkpoint step: {step}") if step is None: - return None + max_logging.log("No WAN checkpoint found.") + return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") metadatas = self.checkpoint_manager.item_metadata(step) transformer_metadata = metadatas.wan_state @@ -86,7 +88,7 @@ def load_diffusers_checkpoint(self): pipeline = WanPipeline.from_pretrained(self.config) return pipeline - def load_checkpoint(self, step=None): + def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]: restored_checkpoint, step = self.load_wan_configs_from_orbax(step) opt_state = None if restored_checkpoint: diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py new file mode 100644 index 000000000..a588aa3d3 --- /dev/null +++ b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -0,0 +1,122 @@ +""" + Copyright 2025 Google LLC + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """
+
+import unittest
+from unittest.mock import ANY, MagicMock, patch
+
+from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT
+
+# Patch targets are resolved in the module under test, where the names are looked up.
+_MANAGER_FACTORY = "maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager"
+_PIPELINE_CLASS = "maxdiffusion.checkpointing.wan_checkpointer.WanPipeline"
+
+def _make_restored_checkpoint(wan_state):
+  """Return a mock restored checkpoint whose "wan_state" item is `wan_state`.
+
+  Item access with any other key raises KeyError.
+  """
+  def getitem_side_effect(key):
+    if key == "wan_state":
+      return wan_state
+    raise KeyError(key)
+
+  restored = MagicMock()
+  restored.wan_state = wan_state
+  restored.wan_config = {}
+  restored.keys.return_value = ["wan_state", "wan_config"]
+  restored.__getitem__.side_effect = getitem_side_effect
+  return restored
+
+def _make_manager(latest_step, restored=None):
+  """Return a mock checkpoint manager whose latest_step() reports `latest_step`.
+
+  When `restored` is given, the manager also serves item metadata and returns
+  `restored` from restore().
+  """
+  manager = MagicMock()
+  manager.latest_step.return_value = latest_step
+  if restored is not None:
+    metadata = MagicMock()
+    metadata.wan_state = {}
+    manager.item_metadata.return_value = metadata
+    manager.restore.return_value = restored
+  return manager
+
+class WanCheckpointerTest(unittest.TestCase):
+  """Tests WanCheckpointer.load_checkpoint with the checkpoint manager and pipeline mocked."""
+
+  def setUp(self):
+    self.config = MagicMock()
+    self.config.checkpoint_dir = "/tmp/wan_checkpoint_test"
+    self.config.dataset_type = "test_dataset"
+
+  def _load_restored(self, mock_wan_pipeline, mock_create_manager, wan_state):
+    """Load step 1 from a mocked orbax checkpoint that holds `wan_state`.
+
+    Checks the restore and from_checkpoint calls shared by both restore tests and
+    returns the (opt_state, step) pair for the test-specific assertions.
+    """
+    restored = _make_restored_checkpoint(wan_state)
+    manager = _make_manager(latest_step=1, restored=restored)
+    mock_create_manager.return_value = manager
+    expected_pipeline = MagicMock()
+    mock_wan_pipeline.from_checkpoint.return_value = expected_pipeline
+
+    checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
+    pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
+
+    manager.restore.assert_called_once_with(directory=ANY, step=1, args=ANY)
+    mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, restored)
+    self.assertEqual(pipeline, expected_pipeline)
+    return opt_state, step
+
+  @patch(_MANAGER_FACTORY)
+  @patch(_PIPELINE_CLASS)
+  def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager):
+    """With no orbax checkpoint step, the pipeline comes from WanPipeline.from_pretrained."""
+    manager = _make_manager(latest_step=None)
+    mock_create_manager.return_value = manager
+    expected_pipeline = MagicMock()
+    mock_wan_pipeline.from_pretrained.return_value = expected_pipeline
+
+    checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
+    pipeline, opt_state, step = checkpointer.load_checkpoint(step=None)
+
+    manager.latest_step.assert_called_once()
+    mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config)
+    self.assertEqual(pipeline, expected_pipeline)
+    self.assertIsNone(opt_state)
+    self.assertIsNone(step)
+
+  @patch(_MANAGER_FACTORY)
+  @patch(_PIPELINE_CLASS)
+  def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager):
+    """A restored "wan_state" without "opt_state" yields no optimizer state."""
+    opt_state, step = self._load_restored(mock_wan_pipeline, mock_create_manager, {"params": {}})
+
+    self.assertIsNone(opt_state)
+    self.assertEqual(step, 1)
+
+  @patch(_MANAGER_FACTORY)
+  @patch(_PIPELINE_CLASS)
+  def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager):
+    """A restored "wan_state" with "opt_state" returns that optimizer state."""
+    wan_state = {"params": {}, "opt_state": {"learning_rate": 0.001}}
+    opt_state, step = self._load_restored(mock_wan_pipeline, mock_create_manager, wan_state)
+
+    self.assertIsNotNone(opt_state)
+    self.assertEqual(opt_state["learning_rate"], 0.001)
+    self.assertEqual(step, 1)
+
+if __name__ == "__main__":
+  unittest.main()