From feebbb10e424797f29dda710682b44565b9deccd Mon Sep 17 00:00:00 2001 From: Kunjan Patel Date: Wed, 22 Oct 2025 21:49:09 +0000 Subject: [PATCH 1/2] Add ability to save optimizer and resume while training --- .../checkpointing/wan_checkpointer.py | 17 +++++---- src/maxdiffusion/configs/base_wan_14b.yml | 1 + .../pipelines/wan/wan_pipeline.py | 5 ++- src/maxdiffusion/pyconfig.py | 2 -- src/maxdiffusion/trainers/wan_trainer.py | 35 ++++++++++++++----- 5 files changed, 43 insertions(+), 17 deletions(-) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 1cd842f67..a1d029a43 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -33,6 +33,7 @@ class WanCheckpointer(ABC): def __init__(self, config, checkpoint_type): self.config = config self.checkpoint_type = checkpoint_type + self.opt_state = None self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( self.config.checkpoint_dir, @@ -57,7 +58,6 @@ def load_wan_configs_from_orbax(self, step): return None max_logging.log(f"Loading WAN checkpoint from step {step}") metadatas = self.checkpoint_manager.item_metadata(step) - transformer_metadata = metadatas.wan_state abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata) params_restore = ocp.args.PyTreeRestore( @@ -73,27 +73,32 @@ def load_wan_configs_from_orbax(self, step): step=step, args=ocp.args.Composite( wan_state=params_restore, - # wan_state=params_restore_util_way, wan_config=ocp.args.JsonRestore(), ), ) - return restored_checkpoint + max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") + max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}") + max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}") + max_logging.log(f"optimizer state saved in 
attribute self.opt_state {self.opt_state}") + return restored_checkpoint, step def load_diffusers_checkpoint(self): pipeline = WanPipeline.from_pretrained(self.config) return pipeline def load_checkpoint(self, step=None): - restored_checkpoint = self.load_wan_configs_from_orbax(step) - + restored_checkpoint, step = self.load_wan_configs_from_orbax(step) + opt_state = None if restored_checkpoint: max_logging.log("Loading WAN pipeline from checkpoint") pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint) + if "opt_state" in restored_checkpoint["wan_state"].keys(): + opt_state = restored_checkpoint["wan_state"]["opt_state"] else: max_logging.log("No checkpoint found, loading default pipeline.") pipeline = self.load_diffusers_checkpoint() - return pipeline + return pipeline, opt_state, step def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict): """Saves the training state and model configurations.""" diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 4a9730454..46285dd85 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -242,6 +242,7 @@ num_eval_samples: 420 warmup_steps_fraction: 0.1 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +save_optimizer: False # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. 
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 115c90545..3e7ce7bf5 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -131,7 +131,10 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): # This helps with loading sharded weights directly into the accelerators without fist copying them # all to one device and then distributing them, thus using low HBM memory. if restored_checkpoint: - params = restored_checkpoint["wan_state"] + if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer + params = restored_checkpoint["wan_state"]["params"] + else: # if not checkpointed with optimizer + params = restored_checkpoint["wan_state"] else: params = load_wan_transformer( config.wan_transformer_pretrained_model_name_or_path, diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 3bb5bd13c..56eeae766 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -196,8 +196,6 @@ def user_init(raw_keys): # Orbax doesn't save the tokenizer params, instead it loads them from the pretrained_model_name_or_path raw_keys["tokenizer_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] - if "gs://" in raw_keys["tokenizer_model_name_or_path"]: - raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp") if "gs://" in raw_keys["pretrained_model_name_or_path"]: raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp") if "gs://" in raw_keys["unet_checkpoint"]: diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index d6a0cc803..89981f1ad 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -17,6 +17,7 @@ import os import datetime import functools +from pprint import 
pprint import numpy as np import threading from concurrent.futures import ThreadPoolExecutor @@ -209,7 +210,11 @@ def prepare_sample_eval(features): def start_training(self): - pipeline = self.load_checkpoint() + pipeline, opt_state, step = self.load_checkpoint() + restore_args = {} + if opt_state and step: + restore_args = {"opt_state": opt_state, "step":step} + del opt_state if self.config.enable_ssim: # Generate a sample before training to compare against generated sample after training. pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") @@ -228,7 +233,7 @@ def start_training(self): pipeline.scheduler_state = scheduler_state optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5) # Returns pipeline with trained transformer state - pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator) + pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args) if self.config.enable_ssim: posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-") @@ -280,18 +285,28 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr if writer: writer.add_scalar("learning/eval_loss", final_eval_loss, step) - def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator): + def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args:dict={}): mesh = pipeline.mesh graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...) 
with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): state = TrainState.create( - apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state - ) + apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state) + if restore_args: + step = restore_args.get("step", 0) + max_logging.log(f"Restoring optimizer and resuming from step {step}") + state.replace(opt_state=restore_args.get("opt_state"), step = restore_args.get("step", 0)) + del restore_args["opt_state"] + del optimizer state = jax.tree.map(_to_array, state) state_spec = nnx.get_partition_spec(state) state = jax.lax.with_sharding_constraint(state, state_spec) state_shardings = nnx.get_named_sharding(state, mesh) + if jax.process_index() == 0 and restore_args: + max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---") + pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60) + max_logging.log(pretty_string) + max_logging.log("------------------------------------------------") data_shardings = self.get_data_shardings(mesh) eval_data_shardings = self.get_eval_data_shardings(mesh) @@ -334,8 +349,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data last_profiling_step = np.clip( first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1 ) - # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint. 
- start_step = 0 + if restore_args.get("step",0): + max_logging.log(f"Resuming training from step {step}") + start_step = restore_args.get("step",0) per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline) scheduler_state = pipeline.scheduler_state example_batch = load_next_batch(train_data_iterator, None, self.config) @@ -373,7 +389,10 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data example_batch = next_batch_future.result() if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0: max_logging.log(f"Saving checkpoint for step {step}") - self.save_checkpoint(step, pipeline, state.params) + if self.config.save_optimizer: + self.save_checkpoint(step, pipeline, state) + else: + self.save_checkpoint(step, pipeline, state.params) _metrics_queue.put(None) writer_thread.join() From e8d8ccfb6cfadffd2793e9fff1bc58cf498e9ccd Mon Sep 17 00:00:00 2001 From: Kunjan Patel Date: Thu, 23 Oct 2025 20:47:40 +0000 Subject: [PATCH 2/2] Bug fix: handle the case where no checkpoint exists and the pipeline is loaded from diffusers --- .github/workflows/UnitTests.yml | 2 +- .../checkpointing/wan_checkpointer.py | 8 +- .../tests/wan_checkpointer_test.py | 122 ++++++++++++++++++ 3 files changed, 128 insertions(+), 4 deletions(-) create mode 100644 src/maxdiffusion/tests/wan_checkpointer_test.py diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 2c588b439..1512485b6 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -58,7 +58,7 @@ jobs: pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets - name: PyTest run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/
TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index a1d029a43..0dd493a33 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -19,6 +19,7 @@ import jax import numpy as np +from typing import Optional, Tuple from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) from ..pipelines.wan.wan_pipeline import WanPipeline from .. import max_logging, max_utils @@ -50,12 +51,13 @@ def _create_optimizer(self, model, config, learning_rate): tx = max_utils.create_optimizer(config, learning_rate_scheduler) return tx, learning_rate_scheduler - def load_wan_configs_from_orbax(self, step): + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: if step is None: step = self.checkpoint_manager.latest_step() max_logging.log(f"Latest WAN checkpoint step: {step}") if step is None: - return None + max_logging.log("No WAN checkpoint found.") + return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") metadatas = self.checkpoint_manager.item_metadata(step) transformer_metadata = metadatas.wan_state @@ -86,7 +88,7 @@ def load_diffusers_checkpoint(self): pipeline = WanPipeline.from_pretrained(self.config) return pipeline - def load_checkpoint(self, step=None): + def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]: restored_checkpoint, step = self.load_wan_configs_from_orbax(step) opt_state = None if restored_checkpoint: diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py new file mode 100644 index 000000000..a588aa3d3 --- /dev/null +++ 
b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -0,0 +1,122 @@ +""" + Copyright 2025 Google LLC + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import unittest +from unittest.mock import patch, MagicMock + +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT + +class WanCheckpointerTest(unittest.TestCase): + def setUp(self): + self.config = MagicMock() + self.config.checkpoint_dir = "/tmp/wan_checkpoint_test" + self.config.dataset_type = "test_dataset" + + @patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager') + @patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline') + def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = None + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) + + mock_manager.latest_step.assert_called_once() + mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertIsNone(step) + + @patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager') + @patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline') + 
def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.wan_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.wan_state = {'params': {}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ['wan_state', 'wan_config'] + def getitem_side_effect(key): + if key == 'wan_state': + return restored_mock.wan_state + raise KeyError(key) + restored_mock.__getitem__.side_effect = getitem_side_effect + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with( + directory=unittest.mock.ANY, + step=1, + args=unittest.mock.ANY + ) + mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertEqual(step, 1) + + @patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager') + @patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline') + def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.wan_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.wan_state = {'params': {}, 'opt_state': {'learning_rate': 0.001}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ['wan_state', 'wan_config'] + def getitem_side_effect(key): + if key == 
'wan_state': + return restored_mock.wan_state + raise KeyError(key) + restored_mock.__getitem__.side_effect = getitem_side_effect + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with( + directory=unittest.mock.ANY, + step=1, + args=unittest.mock.ANY + ) + mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state['learning_rate'], 0.001) + self.assertEqual(step, 1) + +if __name__ == "__main__": + unittest.main()