diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py
index b8dd4ed9c..24c7b2ffd 100644
--- a/src/maxdiffusion/checkpointing/checkpointing_utils.py
+++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -17,15 +17,15 @@
 """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
 
-from typing import Optional, Any
+from typing import Optional, Tuple
 import jax
 import numpy as np
 import os
-
 import orbax.checkpoint
 from maxdiffusion import max_logging
 from etils import epath
 from flax.training import train_state
+from flax.traverse_util import flatten_dict, unflatten_dict
 import orbax
 import orbax.checkpoint as ocp
 from orbax.checkpoint.logging import AbstractLogger
@@ -34,6 +34,7 @@
 STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
 STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
 FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
+WAN_CHECKPOINT = "WAN_CHECKPOINT"
 
 
 def create_orbax_checkpoint_manager(
@@ -59,6 +60,8 @@ def create_orbax_checkpoint_manager(
 
   if checkpoint_type == FLUX_CHECKPOINT:
     item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
+  elif checkpoint_type == WAN_CHECKPOINT:
+    item_names = ("wan_state", "wan_config")
   else:
     item_names = (
       "unet_config",
@@ -78,7 +81,7 @@ def create_orbax_checkpoint_manager(
   if dataset_type == "grain":
     item_names += ("iter",)
 
-  print("item_names: ", item_names)
+  max_logging.log(f"item_names: {item_names}")
 
   mngr = CheckpointManager(
     p,
@@ -133,6 +136,7 @@ def load_params_from_path(
     unboxed_abstract_params,
     checkpoint_item: str,
     step: Optional[int] = None,
+    checkpoint_item_config: Optional[str] = None,
 ):
   ckptr = ocp.PyTreeCheckpointer()
 
diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py
index 8f1e2654e..1cd842f67 100644
--- a/src/maxdiffusion/checkpointing/wan_checkpointer.py
+++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py
@@ -15,9 +15,15 @@
 """
 from abc import ABC
+import json
+
+import jax
+import numpy as np
 
 from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
 from ..pipelines.wan.wan_pipeline import WanPipeline
 from .. import max_logging, max_utils
+import orbax.checkpoint as ocp
+from etils import epath
 
 WAN_CHECKPOINT = "WAN_CHECKPOINT"
 
@@ -28,7 +34,7 @@ def __init__(self, config, checkpoint_type):
     self.config = config
     self.checkpoint_type = checkpoint_type
 
-    self.checkpoint_manager = create_orbax_checkpoint_manager(
+    self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
         self.config.checkpoint_dir,
         enable_checkpointing=True,
         save_interval_steps=1,
@@ -44,22 +50,133 @@ def _create_optimizer(self, model, config, learning_rate):
     return tx, learning_rate_scheduler
 
   def load_wan_configs_from_orbax(self, step):
-    max_logging.log("Restoring stable diffusion configs")
     if step is None:
       step = self.checkpoint_manager.latest_step()
+      max_logging.log(f"Latest WAN checkpoint step: {step}")
     if step is None:
       return None
+    max_logging.log(f"Loading WAN checkpoint from step {step}")
+    metadatas = self.checkpoint_manager.item_metadata(step)
+
+    transformer_metadata = metadatas.wan_state
+    abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
+    params_restore = ocp.args.PyTreeRestore(
+        restore_args=jax.tree.map(
+            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
+            abstract_tree_structure_params,
+        )
+    )
+
+    max_logging.log("Restoring WAN checkpoint")
+    restored_checkpoint = self.checkpoint_manager.restore(
+        directory=epath.Path(self.config.checkpoint_dir),
+        step=step,
+        args=ocp.args.Composite(
+            wan_state=params_restore,
+            wan_config=ocp.args.JsonRestore(),
+        ),
+    )
+    return restored_checkpoint
 
   def load_diffusers_checkpoint(self):
     pipeline = WanPipeline.from_pretrained(self.config)
     return pipeline
 
   def load_checkpoint(self, step=None):
-    model_configs = self.load_wan_configs_from_orbax(step)
+    restored_checkpoint = self.load_wan_configs_from_orbax(step)
 
-    if model_configs:
-      raise NotImplementedError("model configs should not exist in orbax")
+    if restored_checkpoint:
+      max_logging.log("Loading WAN pipeline from checkpoint")
+      pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
     else:
+      max_logging.log("No checkpoint found, loading default pipeline.")
       pipeline = self.load_diffusers_checkpoint()
 
     return pipeline
+
+  def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
+    """Saves the training state and model configurations."""
+
+    def config_to_json(model_or_config):
+      return json.loads(model_or_config.to_json_string())
+
+    max_logging.log(f"Saving checkpoint for step {train_step}")
+    items = {
+        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
+    }
+
+    items["wan_state"] = ocp.args.PyTreeSave(train_states)
+
+    # Save the checkpoint
+    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
+    max_logging.log(f"Checkpoint for step {train_step} saved.")
+
+
+def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
+  """Saves the training state and model configurations."""
+
+  def config_to_json(model_or_config):
+    """
+    Only save config entries that are needed and can be serialized to JSON.
+    """
+    if not hasattr(model_or_config, "config"):
+      return None
+    source_config = dict(model_or_config.config)
+
+    # 1. configs that can be serialized to JSON
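A minimal sketch of the save/restore round trip the new WanCheckpointer methods provide, assuming a populated `config` (with `checkpoint_dir` set) and a trained `pipeline` plus its `train_states` pytree from the trainer; the step number is illustrative:

    # Illustrative only: `config`, `pipeline`, and `train_states` come from the trainer setup.
    from maxdiffusion.checkpointing.wan_checkpointer import WAN_CHECKPOINT, WanCheckpointer

    checkpointer = WanCheckpointer(config, WAN_CHECKPOINT)
    # Save: the transformer config is stored as JSON under "wan_config",
    # the train-state pytree under "wan_state".
    checkpointer.save_checkpoint(100, pipeline, train_states)
    # Restore: builds a WanPipeline from the latest step, or falls back to the
    # pretrained diffusers weights when no checkpoint exists.
    pipeline = checkpointer.load_checkpoint(step=None)
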
+    SAFE_KEYS = [
+        "_class_name",
+        "_diffusers_version",
+        "model_type",
+        "patch_size",
+        "num_attention_heads",
+        "attention_head_dim",
+        "in_channels",
+        "out_channels",
+        "text_dim",
+        "freq_dim",
+        "ffn_dim",
+        "num_layers",
+        "cross_attn_norm",
+        "qk_norm",
+        "eps",
+        "image_dim",
+        "added_kv_proj_dim",
+        "rope_max_seq_len",
+        "pos_embed_seq_len",
+        "flash_min_seq_length",
+        "flash_block_sizes",
+        "attention",
+        "_use_default_values",
+    ]
+
+    # 2. save the configs that are in the SAFE_KEYS list
+    clean_config = {}
+    for key in SAFE_KEYS:
+      if key in source_config:
+        clean_config[key] = source_config[key]
+
+    # 3. handle special data types and precision
+    if "dtype" in source_config and hasattr(source_config["dtype"], "name"):
+      clean_config["dtype"] = source_config["dtype"].name  # e.g. 'bfloat16'
+
+    if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"):
+      clean_config["weights_dtype"] = source_config["weights_dtype"].name
+
+    if "precision" in source_config and hasattr(source_config["precision"], "name"):
+      clean_config["precision"] = source_config["precision"].name  # e.g. 'HIGHEST'
+
+    return clean_config
+
+  items_to_save = {
+      "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
+  }
+
+  items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)
+
+  # Create CompositeArgs for Orbax
+  save_args = ocp.args.Composite(**items_to_save)
+
+  # Save the checkpoint
+  self.checkpoint_manager.save(train_step, args=save_args)
+  max_logging.log(f"Checkpoint for step {train_step} saved.")
diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 0d3fb969b..f25538631 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -198,6 +198,7 @@ remat_policy: "NONE"
 
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
+checkpoint_dir: ""
 
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False
diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py
index bfd420f72..5d1785070 100644
--- a/src/maxdiffusion/configuration_utils.py
+++ b/src/maxdiffusion/configuration_utils.py
@@ -24,13 +24,13 @@
 from collections import OrderedDict
 from pathlib import PosixPath
 from typing import Any, Dict, Tuple, Union
-
+from . import max_logging
 import numpy as np
 from huggingface_hub import create_repo, hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
-
+import jax.numpy as jnp
 from . import __version__
 from .utils import (
     DIFFUSERS_CACHE,
@@ -47,6 +47,21 @@
 _re_configuration_file = re.compile(r"config\.(.*)\.json")
 
 
+class CustomEncoder(json.JSONEncoder):
+  """
+  Custom JSON encoder to handle non-serializable types like JAX/NumPy dtypes.
+  """
+  def default(self, o):
+    # This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
+    if isinstance(o, type(jnp.dtype('bfloat16'))):
+      return str(o)
+    # Add fallbacks for other numpy types if needed
+    if isinstance(o, np.integer):
+      return int(o)
+    if isinstance(o, np.floating):
+      return float(o)
+    # Let the base class default method raise the TypeError for other types
+    return super().default(o)
 
 
 class FrozenDict(OrderedDict):
@@ -579,8 +594,24 @@ def to_json_saveable(value):
     config_dict.pop("precision", None)
     config_dict.pop("weights_dtype", None)
     config_dict.pop("quant", None)
+    keys_to_remove = []
+    for key, value in config_dict.items():
+      # Check the type of the value by its class name to avoid import issues
+      if type(value).__name__ == 'Rngs':
+        keys_to_remove.append(key)
+
+    if keys_to_remove:
+      max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}")
+      for key in keys_to_remove:
+        config_dict.pop(key)
+
+    try:
+      json_str = json.dumps(config_dict, indent=2, sort_keys=True, cls=CustomEncoder)
+    except Exception as e:
+      max_logging.log(f"Error serializing config to JSON: {e}")
+      raise
 
-    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+    return json_str + "\n"
 
   def to_json_file(self, json_file_path: Union[str, os.PathLike]):
     """
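What the new CustomEncoder enables, sketched: `json.dumps` on a config dict that holds NumPy/JAX dtype objects, which the stock encoder rejects with a TypeError. The dict contents below are illustrative:

    # Illustrative only: CustomEncoder stringifies dtypes and unwraps NumPy scalars.
    import json

    import jax.numpy as jnp
    import numpy as np

    from maxdiffusion.configuration_utils import CustomEncoder

    config_dict = {"num_layers": np.int64(40), "dtype": np.dtype(jnp.bfloat16)}
    print(json.dumps(config_dict, sort_keys=True, cls=CustomEncoder))
    # {"dtype": "bfloat16", "num_layers": 40}
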
+ """ + def default(self, o): + # This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16" + if isinstance(o, type(jnp.dtype('bfloat16'))): + return str(o) + # Add fallbacks for other numpy types if needed + if isinstance(o, np.integer): + return int(o) + if isinstance(o, np.floating): + return float(o) + # Let the base class default method raise the TypeError for other types + return super().default(o) class FrozenDict(OrderedDict): @@ -579,8 +594,25 @@ def to_json_saveable(value): config_dict.pop("precision", None) config_dict.pop("weights_dtype", None) config_dict.pop("quant", None) + keys_to_remove = [] + for key, value in config_dict.items(): + # Check the type of the value by its class name to avoid import issues + if type(value).__name__ == 'Rngs': + keys_to_remove.append(key) + + if keys_to_remove: + max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}") + for key in keys_to_remove: + config_dict.pop(key) + + try: + + json_str = json.dumps(config_dict, indent=2, sort_keys=True, cls=CustomEncoder) + except Exception as e: + max_logging.log(f"Error serializing config to JSON: {e}") + raise e - return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + return json_str + "\n" def to_json_file(self, json_file_path: Union[str, os.PathLike]): """ diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index a9bcf366c..519bc8cb3 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -25,6 +25,10 @@ def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) + from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer + + checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT") + pipeline = checkpoint_loader.load_checkpoint() if pipeline is None: pipeline = WanPipeline.from_pretrained(config) s0 = time.perf_counter() diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 1659d3bb5..c9d3cf9df 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -66,14 +66,19 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. -def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): +def create_sharded_logical_transformer( + devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None +): def create_model(rngs: nnx.Rngs, wan_config: dict): wan_transformer = WanModel(**wan_config, rngs=rngs) return wan_transformer # 1. Load config. - wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer") + if restored_checkpoint: + wan_config = restored_checkpoint["wan_config"] + else: + wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer") wan_config["mesh"] = mesh wan_config["dtype"] = config.activations_dtype wan_config["weights_dtype"] = config.weights_dtype @@ -99,11 +104,16 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): # 4. Load pretrained weights and move them to device using the state shardings from (3) above. # This helps with loading sharded weights directly into the accelerators without fist copying them # all to one device and then distributing them, thus using low HBM memory. 
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 1659d3bb5..c9d3cf9df 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -66,14 +66,19 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl
 
 
 # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device.
-def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
+def create_sharded_logical_transformer(
+    devices_array: np.ndarray, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None
+):
 
   def create_model(rngs: nnx.Rngs, wan_config: dict):
     wan_transformer = WanModel(**wan_config, rngs=rngs)
     return wan_transformer
 
   # 1. Load config.
-  wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
+  if restored_checkpoint:
+    wan_config = restored_checkpoint["wan_config"]
+  else:
+    wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
   wan_config["mesh"] = mesh
   wan_config["dtype"] = config.activations_dtype
   wan_config["weights_dtype"] = config.weights_dtype
@@ -99,11 +104,17 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   # 4. Load pretrained weights and move them to device using the state shardings from (3) above.
   # This helps with loading sharded weights directly into the accelerators without first copying them
   # all to one device and then distributing them, thus using low HBM memory.
-  params = load_wan_transformer(
-      config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
-  )
+  if restored_checkpoint:
+    params = restored_checkpoint["wan_state"]
+  else:
+    params = load_wan_transformer(
+        config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
+    )
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
+    if restored_checkpoint:
+      # Restored keys carry an extra trailing path element; trim it so the key matches the sharding tree.
+      path = path[:-1]
     sharding = logical_state_sharding[path].value
     state[path].value = device_put_replicated(val, sharding)
   state = nnx.from_flat_state(state)
@@ -295,9 +305,13 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline):
     return quantized_model
 
   @classmethod
-  def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
+  def load_transformer(
+      cls, devices_array: np.ndarray, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None
+  ):
     with mesh:
-      wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+      wan_transformer = create_sharded_logical_transformer(
+          devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint
+      )
     return wan_transformer
 
   @classmethod
@@ -309,6 +323,45 @@ def load_scheduler(cls, config):
     )
     return scheduler, scheduler_state
 
+  @classmethod
+  def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True):
+    devices_array = max_utils.create_device_mesh(config)
+    mesh = Mesh(devices_array, config.mesh_axes)
+    rng = jax.random.key(config.seed)
+    rngs = nnx.Rngs(rng)
+    transformer = None
+    tokenizer = None
+    scheduler = None
+    scheduler_state = None
+    text_encoder = None
+    if not vae_only:
+      if load_transformer:
+        with mesh:
+          transformer = cls.load_transformer(
+              devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint
+          )
+
+      text_encoder = cls.load_text_encoder(config=config)
+      tokenizer = cls.load_tokenizer(config=config)
+
+      scheduler, scheduler_state = cls.load_scheduler(config=config)
+
+    with mesh:
+      wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+
+    return WanPipeline(
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        transformer=transformer,
+        vae=wan_vae,
+        vae_cache=vae_cache,
+        scheduler=scheduler,
+        scheduler_state=scheduler_state,
+        devices_array=devices_array,
+        mesh=mesh,
+        config=config,
+    )
+
   @classmethod
   def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
     devices_array = max_utils.create_device_mesh(config)
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index a267e0653..7090ad118 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -224,7 +224,6 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint.
     start_step = 0
     per_device_tflops = self.calculate_tflops(pipeline)
-    scheduler_state = pipeline.scheduler_state
     example_batch = load_next_batch(train_data_iterator, None, self.config)
 
     with ThreadPoolExecutor(max_workers=1) as executor:
@@ -274,12 +273,18 @@
         else:
           max_logging.log(f"Step {step}, evaluation dataset was empty.")
       example_batch = next_batch_future.result()
+      if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
+        max_logging.log(f"Saving checkpoint for step {step}")
+        self.save_checkpoint(step, pipeline, state.params)
 
     _metrics_queue.put(None)
     writer_thread.join()
     if writer:
       writer.flush()
-
+    if self.config.save_final_checkpoint:
+      max_logging.log(f"Saving final checkpoint at step {self.config.max_train_steps - 1}")
+      self.save_checkpoint(self.config.max_train_steps - 1, pipeline, state.params)
+    self.checkpoint_manager.wait_until_finished()
     # load new state for trained transformer
     pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
     return pipeline
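A worked example of the periodic-save guard added to the training loop, using the `checkpoint_every` knob from base_wan_14b.yml; the values are illustrative:

    # Illustrative only: with checkpoint_every=500 and max_train_steps=2000,
    # the guard fires at steps 500, 1000, and 1500.
    checkpoint_every, max_train_steps = 500, 2000
    save_steps = [
        step
        for step in range(max_train_steps)
        if step != 0 and checkpoint_every != -1 and step % checkpoint_every == 0
    ]
    assert save_steps == [500, 1000, 1500]
    # When save_final_checkpoint is enabled, a final save also runs at step
    # max_train_steps - 1 (1999 here); checkpoint_every == -1 disables the periodic saves.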