From 996388ae2098ce4df1a33138cb18c3faedab32a6 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sat, 25 Oct 2025 10:44:35 +0530 Subject: [PATCH 01/28] Changes for WAN 2.2 --- .../checkpointing/wan_checkpointer.py | 49 ++- src/maxdiffusion/configs/base_wan_27b.yml | 332 ++++++++++++++++++ src/maxdiffusion/generate_wan.py | 52 ++- src/maxdiffusion/models/wan/wan_utils.py | 9 +- .../pipelines/wan/wan_pipeline.py | 164 ++++++--- 5 files changed, 532 insertions(+), 74 deletions(-) create mode 100644 src/maxdiffusion/configs/base_wan_27b.yml diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 0dd493a33..1f8db8f70 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -34,7 +34,7 @@ class WanCheckpointer(ABC): def __init__(self, config, checkpoint_type): self.config = config self.checkpoint_type = checkpoint_type - self.opt_state = None + self.run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in self.config.__dict__ else False self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( self.config.checkpoint_dir, @@ -60,23 +60,36 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") metadatas = self.checkpoint_manager.item_metadata(step) - transformer_metadata = metadatas.wan_state - abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata) - params_restore = ocp.args.PyTreeRestore( + + restore_args = {} + + low_state_metadata = metadatas.low_noise_transformer_state + abstract_tree_structure_low_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_state_metadata) + low_state_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_params, + abstract_tree_structure_low_state, ) ) + restore_args["low_noise_transformer_state"] = low_state_restore + + if self.run_wan2_2: + high_state_metadata = metadatas.high_noise_transformer_state + abstract_tree_structure_high_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_state_metadata) + high_state_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_high_state, + ) + ) + restore_args["high_noise_transformer_state"] = high_state_restore + + restore_args["wan_config"] = ocp.args.JsonRestore() max_logging.log("Restoring WAN checkpoint") restored_checkpoint = self.checkpoint_manager.restore( - directory=epath.Path(self.config.checkpoint_dir), step=step, - args=ocp.args.Composite( - wan_state=params_restore, - wan_config=ocp.args.JsonRestore(), - ), + args=ocp.args.Composite(**restore_args), ) max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}") @@ -110,14 +123,22 @@ def config_to_json(model_or_config): max_logging.log(f"Saving checkpoint for step {train_step}") items = { - "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), + "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), } - items["wan_state"] = ocp.args.PyTreeSave(train_states) + if "low_noise_transformer" in train_states: + low_noise_state = train_states["low_noise_transformer"] + items["low_noise_transformer_state"] = ocp.args.PyTreeSave(low_noise_state) + if self.run_wan2_2: + if "high_noise_transformer" in train_states: + high_noise_state = train_states["high_noise_transformer"] + items["high_noise_transformer_state"] = ocp.args.PyTreeSave(high_noise_state) + # Save the checkpoint - self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) - max_logging.log(f"Checkpoint for step {train_step} saved.") + if len(items) > 1: + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + max_logging.log(f"Checkpoint for step {train_step} saved.") def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict): diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml new file mode 100644 index 000000000..81fc5914a --- /dev/null +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -0,0 +1,332 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True + +timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written. +write_timing_metrics: True + +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: 'Wan-AI/Wan2.2-T2V-A14B-Diffusers' +run_wan2_2: True + +# Overrides the transformer from pretrained_model_name_or_path +wan_transformer_pretrained_model_name_or_path: '' + +unet_checkpoint: '' +revision: '' +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# Replicates vae across devices instead of using the model's sharding annotations for sharding. +replicate_vae: False + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" +# Use jax.lax.scan for transformer layers +scan_layers: True + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + +# Set true to load weights from pytorch +from_pt: True +split_head_dim: True +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +flash_min_seq_length: 4096 +dropout: 0.1 + +flash_block_sizes: { + "block_q" : 1024, + "block_kv_compute" : 256, + "block_kv" : 1024, + "block_q_dkv" : 1024, + "block_kv_dkv" : 1024, + "block_kv_dkv_compute" : 256, + "block_q_dq" : 1024, + "block_kv_dq" : 1024 +} +# Use on v6e +# flash_block_sizes: { +# "block_q" : 3024, +# "block_kv_compute" : 1024, +# "block_kv" : 2048, +# "block_q_dkv" : 3024, +# "block_kv_dkv" : 2048, +# "block_kv_dkv_compute" : 2048, +# "block_q_dq" : 3024, +# "block_kv_dq" : 2048 +# "use_fused_bwd_kernel": False, +# } +# GroupNorm groups +norm_num_groups: 32 + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +diffusion_scheduler_config: { + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', + rescale_zero_terminal_snr: False, + timestep_spacing: 'trailing' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False + +# Parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', 'data'], + ['activation_length', 'fsdp'], + + ['activation_heads', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['norm', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +allow_split_physical_axes: False + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tfrecord' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '' +load_tfrecord_cached: True +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# Defines the type of gradient checkpoint to enable. +# NONE - means no gradient checkpoint +# FULL - means full gradient checkpoint, whenever possible (minimum memory usage) +# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, +# except for ones that involve batch dimension - that means that all attention and projection +# layers will have gradient checkpoint, but not the backward with respect to the parameters. +# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing. +# CUSTOM - set names to offload and save. +remat_policy: "NONE" +# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj +# xq_out, xk_out, ffn_activation +names_which_can_be_saved: [] +names_which_can_be_offloaded: [] + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +checkpoint_dir: "" +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-5 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 1500 +num_train_epochs: 1 +seed: 0 +output_dir: 'sdxl-model-finetuned' +per_device_batch_size: 1.0 +# If global_batch_size % jax.device_count is not 0, use FSDP sharding. +global_batch_size: 0 + +# For creating tfrecords from dataset +tfrecords_dir: '' +no_records_per_shard: 0 +enable_eval_timesteps: False +timesteps_list: [125, 250, 375, 500, 625, 750, 875] +num_eval_samples: 420 + +warmup_steps_fraction: 0.1 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +save_optimizer: False + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 0 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Generation parameters +prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" +do_classifier_free_guidance: True +height: 480 +width: 832 +num_frames: 81 +flow_shift: 3.0 + +guidance_scale_low: 5.0 +guidance_scale_high: 8.0 +boundary_timestep: 15 + +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 30 +fps: 24 +save_final_checkpoint: False + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. +use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix. +# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80 +quantization_calibration_method: "absmax" +qwix_module_path: ".*" + +# Eval model on per eval_every steps. -1 means don't eval. +eval_every: -1 +eval_data_dir: "" +enable_generate_video_for_eval: False # This will increase the used TPU memory. +eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list). + +enable_ssim: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 501dbf32e..46fca5e87 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -19,10 +19,23 @@ from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline from maxdiffusion import pyconfig, max_logging, max_utils from absl import app +from absl import flags from maxdiffusion.utils import export_to_video from google.cloud import storage import flax +_MODEL_NAME = flags.DEFINE_enum( + "model_name", + default="wan2.1", + enum_values=["wan2.1", "wan2.2"], + help="The model version to run (wan2.1 or wan2.2). This determines the base config file.", +) + +CONFIG_BASE_DIR = "src/maxdiffusion/configs" +MODEL_CONFIG_MAP = { + "wan2.1": "base_wan_14b.yml", + "wan2.2": "base_wan_27b.yml", +} def upload_video_to_gcs(output_dir: str, video_path: str): """ @@ -80,7 +93,10 @@ def inference_generate_video(config, pipeline, filename_prefix=""): width=config.width, num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, + guidance_scale=guidance_scale, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, ) max_logging.log(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}") @@ -107,6 +123,10 @@ def run(config, pipeline=None, filename_prefix=""): # Using global_batch_size_to_train_on so not to create more config variables prompt = [config.prompt] * config.global_batch_size_to_train_on negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on + guidance_scale = config.guidance_scale if 'guidance_scale' in config.__dict__ else 5 + guidance_scale_low = config.guidance_scale_low if 'guidance_scale_low' in config.__dict__ else 3 + guidance_scale_high = config.guidance_scale_high if 'guidance_scale_high' in config.__dict__ else 4 + boundary = config.boundary_timestep if 'boundary_timestep' in config.__dict__ else 875 max_logging.log( f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" @@ -119,7 +139,10 @@ def run(config, pipeline=None, filename_prefix=""): width=config.width, num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, + guidance_scale=guidance_scale, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, ) print("compile time: ", (time.perf_counter() - s0)) @@ -139,7 +162,10 @@ def run(config, pipeline=None, filename_prefix=""): width=config.width, num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, + guidance_scale=guidance_scale, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, ) print("generation time: ", (time.perf_counter() - s0)) @@ -153,7 +179,10 @@ def run(config, pipeline=None, filename_prefix=""): width=config.width, num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, + guidance_scale=guidance_scale, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, ) max_utils.deactivate_profiler(config) print("generation time: ", (time.perf_counter() - s0)) @@ -161,7 +190,20 @@ def run(config, pipeline=None, filename_prefix=""): def main(argv: Sequence[str]) -> None: - pyconfig.initialize(argv) + # Get the model name from the flag + model_key = _MODEL_NAME.value + config_filename = MODEL_CONFIG_MAP[model_key] + selected_yaml_path = os.path.join(CONFIG_BASE_DIR, config_filename) + + max_logging.log(f"Using model: {model_key}, loading base config: {selected_yaml_path}") + + # Construct argv for pyconfig.initialize + # argv[0] is the program name. + # Insert the selected YAML path at index 1. + # The rest of argv (argv[1:]) are the overrides. + argv_for_pyconfig = list(argv[:1]) + [selected_yaml_path] + list(argv[1:]) + + pyconfig.initialize(argv_for_pyconfig) flax.config.update("flax_always_shard_variable", False) run(pyconfig.config) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index ec97abd30..191d8b617 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -184,6 +184,7 @@ def load_wan_transformer( hf_download: bool = True, num_layers: int = 40, scan_layers: bool = True, + subfolder: str = "", ): if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH: @@ -192,7 +193,7 @@ def load_wan_transformer( return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers) else: return load_base_wan_transformer( - pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers + pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers, subfolder ) @@ -203,9 +204,9 @@ def load_base_wan_transformer( hf_download: bool = True, num_layers: int = 40, scan_layers: bool = True, + subfolder: str = "", ): device = jax.local_devices(backend=device)[0] - subfolder = "transformer" filename = "diffusion_pytorch_model.safetensors.index.json" local_files = False if os.path.isdir(pretrained_model_name_or_path): @@ -236,7 +237,7 @@ def load_base_wan_transformer( else: ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file) # now get all the filenames for the model that need downloading - max_logging.log(f"Load and port Wan 2.1 transformer on {device}") + max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}") if ckpt_shard_path is not None: with safe_open(ckpt_shard_path, framework="pt") as f: @@ -281,7 +282,7 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: raise FileNotFoundError(f"File {ckpt_path} not found for local directory.") elif hf_download: ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename) - max_logging.log(f"Load and port Wan 2.1 VAE on {device}") + max_logging.log(f"Load and port {pretrained_model_name_or_path} VAE on {device}") with jax.default_device(device): if ckpt_path is not None: tensors = {} diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 55981be0b..3f3438459 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -89,7 +89,7 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. def create_sharded_logical_transformer( - devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None + devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder: str = "" ): def create_model(rngs: nnx.Rngs, wan_config: dict): @@ -100,7 +100,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): if restored_checkpoint: wan_config = restored_checkpoint["wan_config"] else: - wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer") + wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder) wan_config["mesh"] = mesh wan_config["dtype"] = config.activations_dtype wan_config["weights_dtype"] = config.weights_dtype @@ -142,6 +142,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): "cpu", num_layers=wan_config["num_layers"], scan_layers=config.scan_layers, + subfolder=subfolder, ) params = jax.tree_util.tree_map_with_path( @@ -191,7 +192,8 @@ def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - transformer: WanModel, + low_noise_transformer: WanModel, + high_noise_transformer: Optional[WanModel], vae: AutoencoderKLWan, vae_cache: AutoencoderKLWanCache, scheduler: FlaxUniPCMultistepScheduler, @@ -202,7 +204,8 @@ def __init__( ): self.tokenizer = tokenizer self.text_encoder = text_encoder - self.transformer = transformer + self.low_noise_transformer = low_noise_transformer + self.high_noise_transformer = high_noise_transformer self.vae = vae self.vae_cache = vae_cache self.scheduler = scheduler @@ -210,6 +213,7 @@ def __init__( self.devices_array = devices_array self.mesh = mesh self.config = config + self.run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in self.config.__dict__ else False self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 @@ -353,11 +357,10 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline @classmethod def load_transformer( - cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None - ): + cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder="transformer"): with mesh: wan_transformer = create_sharded_logical_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint + devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder=subfolder ) return wan_transformer @@ -376,7 +379,9 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - transformer = None + run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in config.__dict__ else False + low_noise_transformer = None + high_noise_transformer = None tokenizer = None scheduler = None scheduler_state = None @@ -384,9 +389,9 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ if not vae_only: if load_transformer: with mesh: - transformer = cls.load_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint - ) + low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer") + if run_wan2_2: + high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer_2") text_encoder = cls.load_text_encoder(config=config) tokenizer = cls.load_tokenizer(config=config) @@ -399,7 +404,8 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ return WanPipeline( tokenizer=tokenizer, text_encoder=text_encoder, - transformer=transformer, + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, vae=wan_vae, vae_cache=vae_cache, scheduler=scheduler, @@ -415,7 +421,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - transformer = None + run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in config.__dict__ else False + low_noise_transformer = None + high_noise_transformer = None tokenizer = None scheduler = None scheduler_state = None @@ -423,8 +431,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform if not vae_only: if load_transformer: with mesh: - transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - + low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") + if run_wan2_2: + high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer_2") text_encoder = cls.load_text_encoder(config=config) tokenizer = cls.load_tokenizer(config=config) @@ -436,7 +445,8 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform pipeline = WanPipeline( tokenizer=tokenizer, text_encoder=text_encoder, - transformer=transformer, + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, vae=wan_vae, vae_cache=vae_cache, scheduler=scheduler, @@ -446,7 +456,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform config=config, ) - pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, mesh) + pipeline.low_noise_transformer = cls.quantize_transformer(config, pipeline.low_noise_transformer, pipeline, mesh) + if run_wan2_2: + pipeline.high_noise_transformer = cls.quantize_transformer(config, pipeline.high_noise_transformer, pipeline, mesh) return pipeline def _get_t5_prompt_embeds( @@ -546,6 +558,9 @@ def __call__( num_frames: int = 81, num_inference_steps: int = 50, guidance_scale: float = 5.0, + guidance_scale_low: float = 3.0, + guidance_scale_high: float = 4.0, + boundary: int = 875, num_videos_per_prompt: Optional[int] = 1, max_sequence_length: int = 512, latents: jax.Array = None, @@ -575,7 +590,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - num_channel_latents = self.transformer.config.in_channels + num_channel_latents = self.low_noise_transformer.config.in_channels if latents is None: latents = self.prepare_latents( batch_size=batch_size, @@ -600,22 +615,31 @@ def __call__( self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape ) - graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) + low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) + high_noise_graphdef, high_noise_state, high_noise_rest = None, None, None + if self.run_wan2_2: + high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) p_run_inference = partial( run_inference, + run_wan2_2=self.run_wan2_2, guidance_scale=guidance_scale, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, num_inference_steps=num_inference_steps, scheduler=self.scheduler, scheduler_state=scheduler_state, - num_transformer_layers=self.transformer.config.num_layers, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( - graphdef=graphdef, - sharded_state=state, - rest_of_state=rest_of_state, + low_noise_graphdef=low_noise_graphdef, + low_noise_state=low_noise_state, + low_noise_rest=low_noise_rest, + high_noise_graphdef=high_noise_graphdef, + high_noise_state=high_noise_state, + high_noise_rest=high_noise_rest, latents=latents, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -635,43 +659,74 @@ def __call__( return video -@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) +@partial(jax.jit, static_argnames=("run_wan2_2", "guidance_scale", "guidance_scale_low", "guidance_scale_high", "boundary", "do_classifier_free_guidance")) def transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, + low_noise_graphdef, + low_noise_state, + low_noise_rest, + high_noise_graphdef, + high_noise_state, + high_noise_rest, + latents, timestep, prompt_embeds, - do_classifier_free_guidance, - guidance_scale, + run_wan2_2: bool, + guidance_scale: float, + guidance_scale_low: float, + guidance_scale_high: float, + boundary: int, + do_classifier_free_guidance: bool, + t: jnp.array, ): - wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) - if do_classifier_free_guidance: - bsz = latents.shape[0] // 2 - noise_uncond = noise_pred[bsz:] - noise_pred = noise_pred[:bsz] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - latents = latents[:bsz] + low_noise_transformer = nnx.merge(low_noise_graphdef, low_noise_state, low_noise_rest) + noise_pred_low = low_noise_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) + noise_pred = noise_pred_low + current_guide_scale = guidance_scale + if run_wan2_2: + high_noise_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest) + noise_pred_high = high_noise_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) + use_high_noise = jnp.greater_equal(t, boundary) + noise_pred = jax.lax.cond( + use_high_noise, + lambda: noise_pred_high, + lambda: noise_pred_low, + ) + current_guide_scale = jax.lax.cond( + use_high_noise, + lambda: guidance_scale_high, + lambda: guidance_scale_low, + ) - return noise_pred, latents + if do_classifier_free_guidance: + bsz = latents.shape[0] // 2 + noise_uncond = noise_pred[bsz:] + noise_pred = noise_pred[:bsz] + noise_pred = noise_uncond + current_guide_scale * (noise_pred - noise_uncond) + latents = latents[:bsz] + return noise_pred, latents def run_inference( - graphdef, - sharded_state, - rest_of_state, + low_noise_graphdef, + low_noise_state, + low_noise_rest, + high_noise_graphdef, + high_noise_state, + high_noise_rest, latents: jnp.array, prompt_embeds: jnp.array, negative_prompt_embeds: jnp.array, + run_wan2_2: bool, guidance_scale: float, + guidance_scale_low: float, + guidance_scale_high: float, + boundary: int, num_inference_steps: int, scheduler: FlaxUniPCMultistepScheduler, - num_transformer_layers: int, scheduler_state, ): do_classifier_free_guidance = guidance_scale > 1.0 + if run_wan2_2: + do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 if do_classifier_free_guidance: prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) for step in range(num_inference_steps): @@ -681,14 +736,21 @@ def run_inference( timestep = jnp.broadcast_to(t, latents.shape[0]) noise_pred, latents = transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, + low_noise_graphdef, + low_noise_state, + low_noise_rest, + high_noise_graphdef, + high_noise_state, + high_noise_rest, + latents, timestep, prompt_embeds, - do_classifier_free_guidance=do_classifier_free_guidance, - guidance_scale=guidance_scale, + run_wan2_2, + guidance_scale, + guidance_scale_low, + guidance_scale_high, + boundary, + do_classifier_free_guidance, + t ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() From c094a73bfd0a88f8c1add6fe9b85118fadd79950 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sat, 25 Oct 2025 11:03:04 +0530 Subject: [PATCH 02/28] changes return type of checkpoint_loader to tuple --- src/maxdiffusion/generate_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 46fca5e87..c4b53b8db 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -115,7 +115,7 @@ def run(config, pipeline=None, filename_prefix=""): from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT") - pipeline = checkpoint_loader.load_checkpoint() + pipeline, opt_state, step = checkpoint_loader.load_checkpoint() if pipeline is None: pipeline = WanPipeline.from_pretrained(config) s0 = time.perf_counter() From 33bf49c328e9d92a7e50180b1435191a13c2b624 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sat, 25 Oct 2025 18:42:20 +0530 Subject: [PATCH 03/28] opt_state=None added --- src/maxdiffusion/checkpointing/wan_checkpointer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 1f8db8f70..4295e6333 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -34,6 +34,7 @@ class WanCheckpointer(ABC): def __init__(self, config, checkpoint_type): self.config = config self.checkpoint_type = checkpoint_type + self.opt_state = None self.run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in self.config.__dict__ else False self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( From 8a752e78d711e5b871baeddc2c2700696834b707 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sat, 25 Oct 2025 21:10:39 +0530 Subject: [PATCH 04/28] added model_name in config file --- src/maxdiffusion/configs/base_wan_14b.yml | 1 + src/maxdiffusion/configs/base_wan_27b.yml | 2 +- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 50e66964e..8dea4e3a2 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -28,6 +28,7 @@ save_config_to_gcs: False log_period: 100 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' +model_name: wan2.1 # Overrides the transformer from pretrained_model_name_or_path wan_transformer_pretrained_model_name_or_path: '' diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 81fc5914a..6d005bddb 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -28,7 +28,7 @@ save_config_to_gcs: False log_period: 100 pretrained_model_name_or_path: 'Wan-AI/Wan2.2-T2V-A14B-Diffusers' -run_wan2_2: True +model_name: wan2.2 # Overrides the transformer from pretrained_model_name_or_path wan_transformer_pretrained_model_name_or_path: '' diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 3f3438459..69833842f 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -213,7 +213,7 @@ def __init__( self.devices_array = devices_array self.mesh = mesh self.config = config - self.run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in self.config.__dict__ else False + self.run_wan2_2 = config.model_name == "wan2.2" self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 @@ -379,7 +379,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in config.__dict__ else False + run_wan2_2 = config.model_name == "wan2.2" low_noise_transformer = None high_noise_transformer = None tokenizer = None @@ -421,7 +421,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in config.__dict__ else False + run_wan2_2 = config.model_name == "wan2.2" low_noise_transformer = None high_noise_transformer = None tokenizer = None From 1be0361b80702bc7cdba4dc022b489ca94bf4595 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Mon, 27 Oct 2025 15:04:47 +0530 Subject: [PATCH 05/28] double noise computation fixed --- .../pipelines/wan/wan_pipeline.py | 118 ++++++++---------- 1 file changed, 51 insertions(+), 67 deletions(-) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 69833842f..7e62a236e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -213,7 +213,7 @@ def __init__( self.devices_array = devices_array self.mesh = mesh self.config = config - self.run_wan2_2 = config.model_name == "wan2.2" + self.model_name = config.model_name self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 @@ -379,7 +379,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - run_wan2_2 = config.model_name == "wan2.2" + model_name = config.model_name low_noise_transformer = None high_noise_transformer = None tokenizer = None @@ -390,7 +390,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ if load_transformer: with mesh: low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer") - if run_wan2_2: + if model_name == "wan2.2": high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer_2") text_encoder = cls.load_text_encoder(config=config) @@ -421,7 +421,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - run_wan2_2 = config.model_name == "wan2.2" + model_name = config.model_name low_noise_transformer = None high_noise_transformer = None tokenizer = None @@ -432,7 +432,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform if load_transformer: with mesh: low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") - if run_wan2_2: + if model_name == "wan2.2": high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer_2") text_encoder = cls.load_text_encoder(config=config) tokenizer = cls.load_tokenizer(config=config) @@ -457,7 +457,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform ) pipeline.low_noise_transformer = cls.quantize_transformer(config, pipeline.low_noise_transformer, pipeline, mesh) - if run_wan2_2: + if model_name == "wan2.2": pipeline.high_noise_transformer = cls.quantize_transformer(config, pipeline.high_noise_transformer, pipeline, mesh) return pipeline @@ -617,12 +617,12 @@ def __call__( low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) high_noise_graphdef, high_noise_state, high_noise_rest = None, None, None - if self.run_wan2_2: + if self.model_name == "wan2.2": high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) p_run_inference = partial( run_inference, - run_wan2_2=self.run_wan2_2, + model_name=self.model_name, guidance_scale=guidance_scale, guidance_scale_low=guidance_scale_low, guidance_scale_high=guidance_scale_high, @@ -659,51 +659,27 @@ def __call__( return video -@partial(jax.jit, static_argnames=("run_wan2_2", "guidance_scale", "guidance_scale_low", "guidance_scale_high", "boundary", "do_classifier_free_guidance")) +@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) def transformer_forward_pass( - low_noise_graphdef, - low_noise_state, - low_noise_rest, - high_noise_graphdef, - high_noise_state, - high_noise_rest, - latents, timestep, + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, prompt_embeds, - run_wan2_2: bool, - guidance_scale: float, - guidance_scale_low: float, - guidance_scale_high: float, - boundary: int, - do_classifier_free_guidance: bool, - t: jnp.array, + do_classifier_free_guidance, + guidance_scale, ): - low_noise_transformer = nnx.merge(low_noise_graphdef, low_noise_state, low_noise_rest) - noise_pred_low = low_noise_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) - noise_pred = noise_pred_low - current_guide_scale = guidance_scale - if run_wan2_2: - high_noise_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest) - noise_pred_high = high_noise_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) - use_high_noise = jnp.greater_equal(t, boundary) - noise_pred = jax.lax.cond( - use_high_noise, - lambda: noise_pred_high, - lambda: noise_pred_low, - ) - current_guide_scale = jax.lax.cond( - use_high_noise, - lambda: guidance_scale_high, - lambda: guidance_scale_low, - ) - - if do_classifier_free_guidance: - bsz = latents.shape[0] // 2 - noise_uncond = noise_pred[bsz:] - noise_pred = noise_pred[:bsz] - noise_pred = noise_uncond + current_guide_scale * (noise_pred - noise_uncond) - latents = latents[:bsz] + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) + if do_classifier_free_guidance: + bsz = latents.shape[0] // 2 + noise_uncond = noise_pred[bsz:] + noise_pred = noise_pred[:bsz] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + latents = latents[:bsz] - return noise_pred, latents + return noise_pred, latents def run_inference( low_noise_graphdef, @@ -715,7 +691,7 @@ def run_inference( latents: jnp.array, prompt_embeds: jnp.array, negative_prompt_embeds: jnp.array, - run_wan2_2: bool, + model_name: str, guidance_scale: float, guidance_scale_low: float, guidance_scale_high: float, @@ -725,32 +701,40 @@ def run_inference( scheduler_state, ): do_classifier_free_guidance = guidance_scale > 1.0 - if run_wan2_2: + if model_name == "wan2.2": do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 if do_classifier_free_guidance: prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + + def low_noise_branch(operands): + latents, timestep, prompt_embeds = operands + return transformer_forward_pass( + low_noise_graphdef, low_noise_state, low_noise_rest, + latents, timestep, prompt_embeds, + do_classifier_free_guidance, guidance_scale_low + ) + + def high_noise_branch(operands): + latents, timestep, prompt_embeds = operands + return transformer_forward_pass( + high_noise_graphdef, high_noise_state, high_noise_rest, + latents, timestep, prompt_embeds, + do_classifier_free_guidance, guidance_scale_high + ) + for step in range(num_inference_steps): t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] if do_classifier_free_guidance: latents = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, latents.shape[0]) - noise_pred, latents = transformer_forward_pass( - low_noise_graphdef, - low_noise_state, - low_noise_rest, - high_noise_graphdef, - high_noise_state, - high_noise_rest, - latents, timestep, - prompt_embeds, - run_wan2_2, - guidance_scale, - guidance_scale_low, - guidance_scale_high, - boundary, - do_classifier_free_guidance, - t + use_high_noise = jnp.greater_equal(t, boundary) + + noise_pred, latents = jax.lax.cond( + use_high_noise, + high_noise_branch, + low_noise_branch, + (latents, timestep, prompt_embeds) ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() From 731b07bb9e21e8ef5a949d2d1462ea1c6da43bab Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Mon, 27 Oct 2025 15:14:10 +0530 Subject: [PATCH 06/28] support for wan2.1 in run_inference added --- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 7e62a236e..e6d3df1dc 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -728,6 +728,11 @@ def high_noise_branch(operands): latents = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, latents.shape[0]) + if model_name == "wan2.1": + noise_pred, latents = low_noise_branch((latents, timestep, prompt_embeds)) + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + continue + use_high_noise = jnp.greater_equal(t, boundary) noise_pred, latents = jax.lax.cond( From 11d30fce2ba51a2d938e7848e601ec513ae63904 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 12 Nov 2025 00:08:27 +0530 Subject: [PATCH 07/28] Support for WAN 2.2 added --- README.md | 19 +- .../checkpointing/checkpointing_utils.py | 2 +- .../checkpointing/wan_checkpointer.py | 209 +++++-- src/maxdiffusion/configs/base_wan_27b.yml | 6 +- src/maxdiffusion/generate_wan.py | 123 ++--- .../pipelines/wan/wan_pipeline.py | 508 ++++++++++++------ .../tests/wan_checkpointer_test.py | 224 +++++++- 7 files changed, 766 insertions(+), 325 deletions(-) diff --git a/README.md b/README.md index 7d26dd5d7..2f33c6805 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml) # What's new? +- **`2025/11/11`**: Wan2.2 txt2vid generation is now supported - **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported. - **`2025/10/14`**: NVIDIA DGX Spark Flux support. - **`2025/8/14`**: LTX-Video img2vid generation is now supported. @@ -481,7 +482,23 @@ To generate images, run the following command: ```bash HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 + ``` + ## Wan2.2 + + Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). + + ```bash + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 + ``` + ## Wan2.2 + + Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). + + ```bash + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 ``` ## Flux diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index 24c7b2ffd..bbad3ad1d 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -61,7 +61,7 @@ def create_orbax_checkpoint_manager( if checkpoint_type == FLUX_CHECKPOINT: item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config") elif checkpoint_type == WAN_CHECKPOINT: - item_names = ("wan_state", "wan_config") + item_names = ("low_noise_transformer_state", "high_noise_transformer_state", "wan_state", "wan_config") else: item_names = ( "unet_config", diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 4295e6333..74710f4fd 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -14,35 +14,50 @@ limitations under the License. """ -from abc import ABC +from abc import ABC, abstractmethod import json import jax import numpy as np -from typing import Optional, Tuple +from typing import Optional, Tuple, Type from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) -from ..pipelines.wan.wan_pipeline import WanPipeline +from ..pipelines.wan.wan_pipeline import WanPipeline2_1, WanPipeline2_2 from .. import max_logging, max_utils import orbax.checkpoint as ocp from etils import epath + WAN_CHECKPOINT = "WAN_CHECKPOINT" class WanCheckpointer(ABC): + _SUBCLASS_MAP: dict[str, Type['WanCheckpointer']] = {} + + def __new__(cls, model_key: str, config, checkpoint_type: str = WAN_CHECKPOINT): + if cls is WanCheckpointer: + subclass = cls._SUBCLASS_MAP.get(model_key) + if subclass is None: + raise ValueError( + f"Unknown model_key: '{model_key}'. " + f"Supported keys are: {list(cls._SUBCLASS_MAP.keys())}" + ) + return super().__new__(subclass) + else: + return super().__new__(cls) - def __init__(self, config, checkpoint_type): + def __init__(self, model_key, config, checkpoint_type: str = WAN_CHECKPOINT): self.config = config self.checkpoint_type = checkpoint_type self.opt_state = None - self.run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in self.config.__dict__ else False - - self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( - self.config.checkpoint_dir, - enable_checkpointing=True, - save_interval_steps=1, - checkpoint_type=checkpoint_type, - dataset_type=config.dataset_type, + + self.checkpoint_manager: ocp.CheckpointManager = ( + create_orbax_checkpoint_manager( + self.config.checkpoint_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=checkpoint_type, + dataset_type=config.dataset_type, + ) ) def _create_optimizer(self, model, config, learning_rate): @@ -52,6 +67,25 @@ def _create_optimizer(self, model, config, learning_rate): tx = max_utils.create_optimizer(config, learning_rate_scheduler) return tx, learning_rate_scheduler + @abstractmethod + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: + raise NotImplementedError + + @abstractmethod + def load_diffusers_checkpoint(self): + raise NotImplementedError + + @abstractmethod + def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2], Optional[dict], Optional[int]]: + raise NotImplementedError + + @abstractmethod + def save_checkpoint(self, train_step, pipeline, train_states: dict): + raise NotImplementedError + + +class WanCheckpointer2_1(WanCheckpointer): + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: if step is None: step = self.checkpoint_manager.latest_step() @@ -61,36 +95,23 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") metadatas = self.checkpoint_manager.item_metadata(step) - - restore_args = {} - - low_state_metadata = metadatas.low_noise_transformer_state - abstract_tree_structure_low_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_state_metadata) - low_state_restore = ocp.args.PyTreeRestore( + transformer_metadata = metadatas.wan_state + abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata) + params_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_low_state, + abstract_tree_structure_params, ) ) - restore_args["low_noise_transformer_state"] = low_state_restore - - if self.run_wan2_2: - high_state_metadata = metadatas.high_noise_transformer_state - abstract_tree_structure_high_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_state_metadata) - high_state_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_high_state, - ) - ) - restore_args["high_noise_transformer_state"] = high_state_restore - - restore_args["wan_config"] = ocp.args.JsonRestore() max_logging.log("Restoring WAN checkpoint") restored_checkpoint = self.checkpoint_manager.restore( + directory=epath.Path(self.config.checkpoint_dir), step=step, - args=ocp.args.Composite(**restore_args), + args=ocp.args.Composite( + wan_state=params_restore, + wan_config=ocp.args.JsonRestore(), + ), ) max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}") @@ -99,24 +120,113 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic return restored_checkpoint, step def load_diffusers_checkpoint(self): - pipeline = WanPipeline.from_pretrained(self.config) + pipeline = WanPipeline2_1.from_pretrained(self.config) + return pipeline + + def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_1, Optional[dict], Optional[int]]: + restored_checkpoint, step = self.load_wan_configs_from_orbax(step) + opt_state = None + if restored_checkpoint: + max_logging.log("Loading WAN pipeline from checkpoint") + pipeline = WanPipeline2_1.from_checkpoint(self.config, restored_checkpoint) + if "opt_state" in restored_checkpoint.wan_state.keys(): + opt_state = restored_checkpoint.wan_state["opt_state"] + else: + max_logging.log("No checkpoint found, loading default pipeline.") + pipeline = self.load_diffusers_checkpoint() + + return pipeline, opt_state, step + + def save_checkpoint(self, train_step, pipeline: WanPipeline2_1, train_states: dict): + """Saves the training state and model configurations.""" + + def config_to_json(model_or_config): + return json.loads(model_or_config.to_json_string()) + + max_logging.log(f"Saving checkpoint for step {train_step}") + items = { + "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), + } + + items["wan_state"] = ocp.args.PyTreeSave(train_states) + + # Save the checkpoint + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + max_logging.log(f"Checkpoint for step {train_step} saved.") + + +class WanCheckpointer2_2(WanCheckpointer): + + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: + if step is None: + step = self.checkpoint_manager.latest_step() + max_logging.log(f"Latest WAN checkpoint step: {step}") + if step is None: + max_logging.log("No WAN checkpoint found.") + return None, None + max_logging.log(f"Loading WAN checkpoint from step {step}") + metadatas = self.checkpoint_manager.item_metadata(step) + + # Handle low_noise_transformer + low_noise_transformer_metadata = metadatas.low_noise_transformer_state + abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) + low_params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_low_params, + ) + ) + + # Handle high_noise_transformer + high_noise_transformer_metadata = metadatas.high_noise_transformer_state + abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) + high_params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_high_params, + ) + ) + + max_logging.log("Restoring WAN 2.2 checkpoint") + restored_checkpoint = self.checkpoint_manager.restore( + directory=epath.Path(self.config.checkpoint_dir), + step=step, + args=ocp.args.Composite( + low_noise_transformer_state=low_params_restore, + high_noise_transformer_state=high_params_restore, + wan_config=ocp.args.JsonRestore(), + ), + ) + max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") + max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") + max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") + max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") + return restored_checkpoint, step + + def load_diffusers_checkpoint(self): + pipeline = WanPipeline2_2.from_pretrained(self.config) return pipeline - def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]: + def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]: restored_checkpoint, step = self.load_wan_configs_from_orbax(step) opt_state = None if restored_checkpoint: max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint) - if "opt_state" in restored_checkpoint["wan_state"].keys(): - opt_state = restored_checkpoint["wan_state"]["opt_state"] + pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint) + # Check for optimizer state in either transformer + if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): + opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] + elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): + opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] else: max_logging.log("No checkpoint found, loading default pipeline.") pipeline = self.load_diffusers_checkpoint() return pipeline, opt_state, step - def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict): + def save_checkpoint(self, train_step, pipeline: WanPipeline2_2, train_states: dict): """Saves the training state and model configurations.""" def config_to_json(model_or_config): @@ -127,22 +237,17 @@ def config_to_json(model_or_config): "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), } - if "low_noise_transformer" in train_states: - low_noise_state = train_states["low_noise_transformer"] - items["low_noise_transformer_state"] = ocp.args.PyTreeSave(low_noise_state) + items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"]) + items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"]) - if self.run_wan2_2: - if "high_noise_transformer" in train_states: - high_noise_state = train_states["high_noise_transformer"] - items["high_noise_transformer_state"] = ocp.args.PyTreeSave(high_noise_state) - # Save the checkpoint - if len(items) > 1: - self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) - max_logging.log(f"Checkpoint for step {train_step} saved.") + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + max_logging.log(f"Checkpoint for step {train_step} saved.") +WanCheckpointer._SUBCLASS_MAP["wan2.1"] = WanCheckpointer2_1 +WanCheckpointer._SUBCLASS_MAP["wan2.2"] = WanCheckpointer2_2 -def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict): +def save_checkpoint_orig(self, train_step, pipeline, train_states: dict): """Saves the training state and model configurations.""" def config_to_json(model_or_config): diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 6d005bddb..323a1a51c 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -272,9 +272,9 @@ width: 832 num_frames: 81 flow_shift: 3.0 -guidance_scale_low: 5.0 -guidance_scale_high: 8.0 -boundary_timestep: 15 +guidance_scale_low: 3.0 +guidance_scale_high: 4.0 +boundary_timestep: 875 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index c4b53b8db..53a38ac45 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -17,25 +17,13 @@ import time import os from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer from maxdiffusion import pyconfig, max_logging, max_utils from absl import app -from absl import flags from maxdiffusion.utils import export_to_video from google.cloud import storage import flax -_MODEL_NAME = flags.DEFINE_enum( - "model_name", - default="wan2.1", - enum_values=["wan2.1", "wan2.2"], - help="The model version to run (wan2.1 or wan2.2). This determines the base config file.", -) - -CONFIG_BASE_DIR = "src/maxdiffusion/configs" -MODEL_CONFIG_MAP = { - "wan2.1": "base_wan_14b.yml", - "wan2.2": "base_wan_27b.yml", -} def upload_video_to_gcs(output_dir: str, video_path: str): """ @@ -76,6 +64,33 @@ def delete_file(file_path: str): jax.config.update("jax_use_shardy_partitioner", True) +def call_pipeline(config, pipeline, prompt, negative_prompt): + model_key = config.model_name + if model_key == "wan2.1": + return pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + ) + elif model_key == "wan2.2": + return pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale_low=config.guidance_scale_low, + guidance_scale_high=config.guidance_scale_high, + boundary=config.boundary_timestep, + ) + else: + raise ValueError(f"Unsupported model_name in config: {model_key}") + def inference_generate_video(config, pipeline, filename_prefix=""): s0 = time.perf_counter() @@ -86,18 +101,7 @@ def inference_generate_video(config, pipeline, filename_prefix=""): f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}, video: {filename_prefix}" ) - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=guidance_scale, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - ) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) max_logging.log(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}") for i in range(len(videos)): @@ -112,38 +116,20 @@ def inference_generate_video(config, pipeline, filename_prefix=""): def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) - from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer - - checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT") - pipeline, opt_state, step = checkpoint_loader.load_checkpoint() - if pipeline is None: - pipeline = WanPipeline.from_pretrained(config) + model_key = config.model_name + checkpoint_loader = WanCheckpointer(model_key=model_key, config=config) + pipeline, _, _ = checkpoint_loader.load_checkpoint() + pipeline = WanPipeline.from_pretrained(model_key=model_key, config=config) s0 = time.perf_counter() # Using global_batch_size_to_train_on so not to create more config variables prompt = [config.prompt] * config.global_batch_size_to_train_on negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on - guidance_scale = config.guidance_scale if 'guidance_scale' in config.__dict__ else 5 - guidance_scale_low = config.guidance_scale_low if 'guidance_scale_low' in config.__dict__ else 3 - guidance_scale_high = config.guidance_scale_high if 'guidance_scale_high' in config.__dict__ else 4 - boundary = config.boundary_timestep if 'boundary_timestep' in config.__dict__ else 875 max_logging.log( f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" ) - - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=guidance_scale, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - ) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) print("compile time: ", (time.perf_counter() - s0)) saved_video_path = [] @@ -155,55 +141,20 @@ def run(config, pipeline=None, filename_prefix=""): upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path) s0 = time.perf_counter() - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=guidance_scale, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - ) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) print("generation time: ", (time.perf_counter() - s0)) s0 = time.perf_counter() if config.enable_profiler: max_utils.activate_profiler(config) - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=guidance_scale, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - ) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) max_utils.deactivate_profiler(config) print("generation time: ", (time.perf_counter() - s0)) return saved_video_path def main(argv: Sequence[str]) -> None: - # Get the model name from the flag - model_key = _MODEL_NAME.value - config_filename = MODEL_CONFIG_MAP[model_key] - selected_yaml_path = os.path.join(CONFIG_BASE_DIR, config_filename) - - max_logging.log(f"Using model: {model_key}, loading base config: {selected_yaml_path}") - - # Construct argv for pyconfig.initialize - # argv[0] is the program name. - # Insert the selected YAML path at index 1. - # The rest of argv (argv[1:]) are the overrides. - argv_for_pyconfig = list(argv[:1]) + [selected_yaml_path] + list(argv[1:]) - - pyconfig.initialize(argv_for_pyconfig) + pyconfig.initialize(argv) flax.config.update("flax_always_shard_variable", False) run(pyconfig.config) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index e6d3df1dc..e2bfd7e86 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union, Optional +from abc import abstractmethod +from typing import List, Union, Optional, Type from functools import partial import numpy as np import jax @@ -187,13 +188,11 @@ class WanPipeline: vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. """ - + _SUBCLASS_MAP: dict[str, Type['WanPipeline']] = {} def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - low_noise_transformer: WanModel, - high_noise_transformer: Optional[WanModel], vae: AutoencoderKLWan, vae_cache: AutoencoderKLWanCache, scheduler: FlaxUniPCMultistepScheduler, @@ -204,8 +203,6 @@ def __init__( ): self.tokenizer = tokenizer self.text_encoder = text_encoder - self.low_noise_transformer = low_noise_transformer - self.high_noise_transformer = high_noise_transformer self.vae = vae self.vae_cache = vae_cache self.scheduler = scheduler @@ -373,93 +370,6 @@ def load_scheduler(cls, config): ) return scheduler, scheduler_state - @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - model_name = config.model_name - low_noise_transformer = None - high_noise_transformer = None - tokenizer = None - scheduler = None - scheduler_state = None - text_encoder = None - if not vae_only: - if load_transformer: - with mesh: - low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer") - if model_name == "wan2.2": - high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer_2") - - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - return WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - low_noise_transformer=low_noise_transformer, - high_noise_transformer=high_noise_transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config, - ) - - @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - model_name = config.model_name - low_noise_transformer = None - high_noise_transformer = None - tokenizer = None - scheduler = None - scheduler_state = None - text_encoder = None - if not vae_only: - if load_transformer: - with mesh: - low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") - if model_name == "wan2.2": - high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer_2") - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - pipeline = WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - low_noise_transformer=low_noise_transformer, - high_noise_transformer=high_noise_transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config, - ) - - pipeline.low_noise_transformer = cls.quantize_transformer(config, pipeline.low_noise_transformer, pipeline, mesh) - if model_name == "wan2.2": - pipeline.high_noise_transformer = cls.quantize_transformer(config, pipeline.high_noise_transformer, pipeline, mesh) - return pipeline def _get_t5_prompt_embeds( self, @@ -549,25 +459,86 @@ def prepare_latents( return latents - def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - guidance_scale_low: float = 3.0, - guidance_scale_high: float = 4.0, - boundary: int = 875, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: jax.Array = None, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, - ): + def _denormalize_latents(self, latents: jax.Array) -> jax.Array: + """Denormalizes latents using VAE statistics.""" + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) + latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) + latents = latents / latents_std + latents_mean + latents = latents.astype(jnp.float32) + return latents + + def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray: + """Decodes latents to video frames and postprocesses.""" + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + video = self.vae.decode(latents, self.vae_cache)[0] + + video = jnp.transpose(video, (0, 4, 1, 2, 3)) + video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) + video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) + return self.video_processor.postprocess_video(video, output_type="np") + + @classmethod + def _create_common_components(cls, config, vae_only=False): + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + rng = jax.random.key(config.seed) + rngs = nnx.Rngs(rng) + + with mesh: + wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + + components = { + "vae": wan_vae, "vae_cache": vae_cache, + "devices_array": devices_array, "rngs": rngs, "mesh": mesh, + "tokenizer": None, "text_encoder": None, "scheduler": None, "scheduler_state": None + } + + if not vae_only: + components["tokenizer"] = cls.load_tokenizer(config=config) + components["text_encoder"] = cls.load_text_encoder(config=config) + components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) + return components + + @classmethod + def _get_subclass(cls, model_key: str) -> Type['WanPipeline']: + subclass = cls._SUBCLASS_MAP.get(model_key) + if subclass is None: + raise ValueError( + f"Unknown model_key for WanPipeline: '{model_key}'. " + f"Supported keys are: {list(cls._SUBCLASS_MAP.keys())}" + ) + return subclass + + @classmethod + def from_checkpoint(cls, model_key: str, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + subclass = cls._get_subclass(model_key) + return subclass.from_checkpoint(config, restored_checkpoint=restored_checkpoint, vae_only=vae_only, load_transformer=load_transformer) + + @classmethod + def from_pretrained(cls, model_key: str, config: HyperParameters, vae_only=False, load_transformer=True): + subclass = cls._get_subclass(model_key) + return subclass.from_pretrained(config, vae_only=vae_only, load_transformer=load_transformer) + + @abstractmethod + def _get_num_channel_latents(self) -> int: + """Returns the number of input channels for the transformer.""" + pass + + def _prepare_call_inputs( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, + ): if not vae_only: if num_frames % self.vae_scale_factor_temporal != 1: max_logging.log( @@ -590,7 +561,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - num_channel_latents = self.low_noise_transformer.config.in_channels + num_channel_latents = self._get_num_channel_latents() if latents is None: latents = self.prepare_latents( batch_size=batch_size, @@ -615,49 +586,235 @@ def __call__( self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape ) - low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) - high_noise_graphdef, high_noise_state, high_noise_rest = None, None, None - if self.model_name == "wan2.2": - high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) - - p_run_inference = partial( - run_inference, - model_name=self.model_name, - guidance_scale=guidance_scale, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - scheduler_state=scheduler_state, - ) + return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - latents = p_run_inference( - low_noise_graphdef=low_noise_graphdef, - low_noise_state=low_noise_state, - low_noise_rest=low_noise_rest, - high_noise_graphdef=high_noise_graphdef, - high_noise_state=high_noise_state, - high_noise_rest=high_noise_rest, - latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, + @abstractmethod + def __call__(self, **kwargs): + """Runs the inference pipeline.""" + pass + +class WanPipeline2_1(WanPipeline): + """Pipeline for WAN 2.1 with a single transformer.""" + def __init__(self, config: HyperParameters, transformer: Optional[WanModel], **kwargs): + super().__init__(config=config, **kwargs) + self.transformer = transformer + + @classmethod + def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): + common_components = cls._create_common_components(config, vae_only) + transformer = None + if not vae_only: + if load_transformer: + transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer" ) - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) - latents = latents / latents_std + latents_mean - latents = latents.astype(jnp.float32) + + pipeline = cls( + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + transformer=transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, + ) + + return pipeline, transformer + + @classmethod + def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): + pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer) + transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) + return pipeline + + @classmethod + def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) + return pipeline + + def _get_num_channel_latents(self) -> int: + return self.transformer.config.in_channels + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: Optional[jax.Array] = None, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + vae_only: bool = False, + ): + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_call_inputs( + prompt, + negative_prompt, + height, + width, + num_frames, + num_inference_steps, + num_videos_per_prompt, + max_sequence_length, + latents, + prompt_embeds, + negative_prompt_embeds, + vae_only, + ) + + graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) + + p_run_inference = partial( + run_inference_2_1, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state, + ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - video = self.vae.decode(latents, self.vae_cache)[0] + latents = p_run_inference( + graphdef=graphdef, + sharded_state=state, + rest_of_state=rest_of_state, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + latents = self._denormalize_latents(latents) + return self._decode_latents_to_video(latents) + +class WanPipeline2_2(WanPipeline): + """Pipeline for WAN 2.2 with dual transformers.""" + def __init__(self, config: HyperParameters, low_noise_transformer: Optional[WanModel], high_noise_transformer: Optional[WanModel], **kwargs): + super().__init__(config=config, **kwargs) + self.low_noise_transformer = low_noise_transformer + self.high_noise_transformer = high_noise_transformer + + @classmethod + def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): + common_components = cls._create_common_components(config, vae_only) + low_noise_transformer, high_noise_transformer = None, None + if not vae_only and load_transformer: + low_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer" + ) + high_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer_2" + ) + + pipeline = cls( + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, + ) + return pipeline, low_noise_transformer, high_noise_transformer + + @classmethod + def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): + pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer) + low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh) + high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh) + return pipeline + + @classmethod + def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) + return pipeline - video = jnp.transpose(video, (0, 4, 1, 2, 3)) - video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) - video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) - video = self.video_processor.postprocess_video(video, output_type="np") - return video + def _get_num_channel_latents(self) -> int: + return self.low_noise_transformer.config.in_channels + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale_low: float = 3.0, + guidance_scale_high: float = 4.0, + boundary: int = 875, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, + ): + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_call_inputs( + prompt, + negative_prompt, + height, + width, + num_frames, + num_inference_steps, + num_videos_per_prompt, + max_sequence_length, + latents, + prompt_embeds, + negative_prompt_embeds, + vae_only, + ) + + low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) + high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) + + p_run_inference = partial( + run_inference_2_2, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state, + ) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + latents = p_run_inference( + low_noise_graphdef=low_noise_graphdef, + low_noise_state=low_noise_state, + low_noise_rest=low_noise_rest, + high_noise_graphdef=high_noise_graphdef, + high_noise_state=high_noise_state, + high_noise_rest=high_noise_rest, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + latents = self._denormalize_latents(latents) + return self._decode_latents_to_video(latents) @partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) def transformer_forward_pass( @@ -681,7 +838,42 @@ def transformer_forward_pass( return noise_pred, latents -def run_inference( +def run_inference_2_1( + graphdef, + sharded_state, + rest_of_state, + latents: jnp.array, + prompt_embeds: jnp.array, + negative_prompt_embeds: jnp.array, + guidance_scale: float, + num_inference_steps: int, + scheduler: FlaxUniPCMultistepScheduler, + scheduler_state, +): + do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + if do_classifier_free_guidance: + latents = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, latents.shape[0]) + + noise_pred, latents = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_embeds, + do_classifier_free_guidance=do_classifier_free_guidance, + guidance_scale=guidance_scale, + ) + + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents + +def run_inference_2_2( low_noise_graphdef, low_noise_state, low_noise_rest, @@ -691,8 +883,6 @@ def run_inference( latents: jnp.array, prompt_embeds: jnp.array, negative_prompt_embeds: jnp.array, - model_name: str, - guidance_scale: float, guidance_scale_low: float, guidance_scale_high: float, boundary: int, @@ -700,9 +890,7 @@ def run_inference( scheduler: FlaxUniPCMultistepScheduler, scheduler_state, ): - do_classifier_free_guidance = guidance_scale > 1.0 - if model_name == "wan2.2": - do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 + do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 if do_classifier_free_guidance: prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) @@ -728,11 +916,6 @@ def high_noise_branch(operands): latents = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, latents.shape[0]) - if model_name == "wan2.1": - noise_pred, latents = low_noise_branch((latents, timestep, prompt_embeds)) - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - continue - use_high_noise = jnp.greater_equal(t, boundary) noise_pred, latents = jax.lax.cond( @@ -744,3 +927,6 @@ def high_noise_branch(operands): latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents + +WanPipeline._SUBCLASS_MAP["wan2.1"] = WanPipeline2_1 +WanPipeline._SUBCLASS_MAP["wan2.2"] = WanPipeline2_2 diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py index ab5b5ca3a..554c8824c 100644 --- a/src/maxdiffusion/tests/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -14,10 +14,10 @@ import unittest from unittest.mock import patch, MagicMock -from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer2_1, WanCheckpointer2_2 - -class WanCheckpointerTest(unittest.TestCase): +class WanCheckpointer2_1Test(unittest.TestCase): + """Tests for WAN 2.1 checkpointer.""" def setUp(self): self.config = MagicMock() @@ -25,7 +25,7 @@ def setUp(self): self.config.dataset_type = "test_dataset" @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = None @@ -34,7 +34,7 @@ def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + checkpointer = WanCheckpointer2_1(model_key=self.config.model_key, config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) mock_manager.latest_step.assert_called_once() @@ -44,7 +44,7 @@ def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): self.assertIsNone(step) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 @@ -57,12 +57,6 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag restored_mock.wan_config = {} restored_mock.keys.return_value = ["wan_state", "wan_config"] - def getitem_side_effect(key): - if key == "wan_state": - return restored_mock.wan_state - raise KeyError(key) - - restored_mock.__getitem__.side_effect = getitem_side_effect mock_manager.restore.return_value = restored_mock mock_create_manager.return_value = mock_manager @@ -70,7 +64,7 @@ def getitem_side_effect(key): mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + checkpointer = WanCheckpointer2_1(model_key=self.config.model_key, config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) @@ -80,7 +74,7 @@ def getitem_side_effect(key): self.assertEqual(step, 1) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 @@ -93,12 +87,102 @@ def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_man restored_mock.wan_config = {} restored_mock.keys.return_value = ["wan_state", "wan_config"] - def getitem_side_effect(key): - if key == "wan_state": - return restored_mock.wan_state - raise KeyError(key) + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_1(model_key=self.config.model_key, config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.001) + self.assertEqual(step, 1) + + +class WanCheckpointer2_2Test(unittest.TestCase): + """Tests for WAN 2.2 checkpointer.""" + + def setUp(self): + self.config = MagicMock() + self.config.checkpoint_dir = "/tmp/wan_checkpoint_2_2_test" + self.config.dataset_type = "test_dataset" + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): + """Test loading from pretrained when no checkpoint exists.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = None + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) + + mock_manager.latest_step.assert_called_once() + mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertIsNone(step) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint without optimizer state.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}} + restored_mock.high_noise_transformer_state = {"params": {}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertEqual(step, 1) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint with optimizer state in low_noise_transformer.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.001}} + restored_mock.high_noise_transformer_state = {"params": {}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] - restored_mock.__getitem__.side_effect = getitem_side_effect mock_manager.restore.return_value = restored_mock mock_create_manager.return_value = mock_manager @@ -106,7 +190,7 @@ def getitem_side_effect(key): mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) @@ -116,6 +200,104 @@ def getitem_side_effect(key): self.assertEqual(opt_state["learning_rate"], 0.001) self.assertEqual(step, 1) + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint with optimizer state in high_noise_transformer.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}} + restored_mock.high_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.002}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.002) + self.assertEqual(step, 1) + + +class WanCheckpointerEdgeCasesTest(unittest.TestCase): + """Tests for edge cases and error handling.""" + + def setUp(self): + self.config = MagicMock() + self.config.checkpoint_dir = "/tmp/wan_checkpoint_edge_test" + self.config.dataset_type = "test_dataset" + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") + def test_load_checkpoint_with_explicit_none_step(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint with explicit None step falls back to latest.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 5 + metadata_mock = MagicMock() + metadata_mock.wan_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.wan_state = {"params": {}} + restored_mock.wan_config = {} + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_1(model_key=self.config.model_key, config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) + + mock_manager.latest_step.assert_called_once() + self.assertEqual(step, 5) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_checkpoint_both_optimizers_present(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint when both transformers have optimizer state (prioritize low_noise).""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.001}} + restored_mock.high_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.002}} + restored_mock.wan_config = {} + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + # Should prioritize low_noise_transformer's optimizer state + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.001) + if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=2) From ce17ed0f41b9944af6928d53df8a34fb2a488703 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 12 Nov 2025 00:12:50 +0530 Subject: [PATCH 08/28] Removed extra files --- .../checkpointing/wan_checkpointer2_2.py | 207 ----- .../pipelines/wan/wan_pipeline2_2.py | 725 ------------------ .../tests/wan_checkpointer2_2_test.py | 113 --- 3 files changed, 1045 deletions(-) delete mode 100644 src/maxdiffusion/checkpointing/wan_checkpointer2_2.py delete mode 100644 src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py delete mode 100644 src/maxdiffusion/tests/wan_checkpointer2_2_test.py diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer2_2.py deleted file mode 100644 index de8bb35d6..000000000 --- a/src/maxdiffusion/checkpointing/wan_checkpointer2_2.py +++ /dev/null @@ -1,207 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -from abc import ABC -import json - -import jax -import numpy as np -from typing import Optional, Tuple -from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) -from ..pipelines.wan.wan_pipeline2_2 import WanPipeline -from .. import max_logging, max_utils -import orbax.checkpoint as ocp -from etils import epath - -WAN_CHECKPOINT = "WAN_CHECKPOINT" - - -class WanCheckpointer(ABC): - - def __init__(self, config, checkpoint_type): - self.config = config - self.checkpoint_type = checkpoint_type - self.opt_state = None - - self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( - self.config.checkpoint_dir, - enable_checkpointing=True, - save_interval_steps=1, - checkpoint_type=checkpoint_type, - dataset_type=config.dataset_type, - ) - - def _create_optimizer(self, model, config, learning_rate): - learning_rate_scheduler = max_utils.create_learning_rate_schedule( - learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps - ) - tx = max_utils.create_optimizer(config, learning_rate_scheduler) - return tx, learning_rate_scheduler - - def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: - if step is None: - step = self.checkpoint_manager.latest_step() - max_logging.log(f"Latest WAN checkpoint step: {step}") - if step is None: - max_logging.log("No WAN checkpoint found.") - return None, None - max_logging.log(f"Loading WAN checkpoint from step {step}") - metadatas = self.checkpoint_manager.item_metadata(step) - - low_noise_transformer_metadata = metadatas.low_noise_transformer_state - abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) - low_params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_low_params, - ) - ) - - high_noise_transformer_metadata = metadatas.high_noise_transformer_state - abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) - high_params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_high_params, - ) - ) - - max_logging.log("Restoring WAN checkpoint") - restored_checkpoint = self.checkpoint_manager.restore( - directory=epath.Path(self.config.checkpoint_dir), - step=step, - args=ocp.args.Composite( - low_noise_transformer_state=low_params_restore, - high_noise_transformer_state=high_params_restore, - wan_config=ocp.args.JsonRestore(), - ), - ) - max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") - max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") - max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") - return restored_checkpoint, step - - def load_diffusers_checkpoint(self): - pipeline = WanPipeline.from_pretrained(self.config) - return pipeline - - def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]: - restored_checkpoint, step = self.load_wan_configs_from_orbax(step) - opt_state = None - if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint) - # Check for optimizer state in either transformer - if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): - opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] - elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): - opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] - else: - max_logging.log("No checkpoint found, loading default pipeline.") - pipeline = self.load_diffusers_checkpoint() - - return pipeline, opt_state, step - - def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict): - """Saves the training state and model configurations.""" - - def config_to_json(model_or_config): - return json.loads(model_or_config.to_json_string()) - - max_logging.log(f"Saving checkpoint for step {train_step}") - items = { - "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), - } - - items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"]) - items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"]) - - # Save the checkpoint - self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) - max_logging.log(f"Checkpoint for step {train_step} saved.") - - -def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict): - """Saves the training state and model configurations.""" - - def config_to_json(model_or_config): - """ - only save the config that is needed and can be serialized to JSON. - """ - if not hasattr(model_or_config, "config"): - return None - source_config = dict(model_or_config.config) - - # 1. configs that can be serialized to JSON - SAFE_KEYS = [ - "_class_name", - "_diffusers_version", - "model_type", - "patch_size", - "num_attention_heads", - "attention_head_dim", - "in_channels", - "out_channels", - "text_dim", - "freq_dim", - "ffn_dim", - "num_layers", - "cross_attn_norm", - "qk_norm", - "eps", - "image_dim", - "added_kv_proj_dim", - "rope_max_seq_len", - "pos_embed_seq_len", - "flash_min_seq_length", - "flash_block_sizes", - "attention", - "_use_default_values", - ] - - # 2. save the config that are in the SAFE_KEYS list - clean_config = {} - for key in SAFE_KEYS: - if key in source_config: - clean_config[key] = source_config[key] - - # 3. deal with special data type and precision - if "dtype" in source_config and hasattr(source_config["dtype"], "name"): - clean_config["dtype"] = source_config["dtype"].name # e.g 'bfloat16' - - if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"): - clean_config["weights_dtype"] = source_config["weights_dtype"].name - - if "precision" in source_config and isinstance(source_config["precision"]): - clean_config["precision"] = source_config["precision"].name # e.g. 'HIGHEST' - - return clean_config - - items_to_save = { - "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), - } - - items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states) - - # Create CompositeArgs for Orbax - save_args = ocp.args.Composite(**items_to_save) - - # Save the checkpoint - self.checkpoint_manager.save(train_step, args=save_args) - max_logging.log(f"Checkpoint for step {train_step} saved.") diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py deleted file mode 100644 index 0645aeeb6..000000000 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py +++ /dev/null @@ -1,725 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Union, Optional -from functools import partial -import numpy as np -import jax -import jax.numpy as jnp -from jax.sharding import Mesh, NamedSharding, PartitionSpec as P -import flax -import flax.linen as nn -from flax import nnx -from flax.linen import partitioning as nn_partitioning -from ...pyconfig import HyperParameters -from ... import max_logging -from ... import max_utils -from ...max_utils import get_flash_block_sizes, get_precision, device_put_replicated -from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae -from ...models.wan.transformers.transformer_wan import WanModel -from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache -from maxdiffusion.video_processor import VideoProcessor -from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState -from transformers import AutoTokenizer, UMT5EncoderModel -from maxdiffusion.utils.import_utils import is_ftfy_available -from maxdiffusion.maxdiffusion_utils import get_dummy_wan_inputs -import html -import re -import torch -import qwix - - -def cast_with_exclusion(path, x, dtype_to_cast): - """ - Casts arrays to dtype_to_cast, but keeps params from any 'norm' layer in float32. - """ - - exclusion_keywords = [ - "norm", # For all LayerNorm/GroupNorm layers - "condition_embedder", # The entire time/text conditioning module - "scale_shift_table", # Catches both the final and the AdaLN tables - ] - - path_str = ".".join(str(k.key) if isinstance(k, jax.tree_util.DictKey) else str(k) for k in path) - - if any(keyword in path_str.lower() for keyword in exclusion_keywords): - print("is_norm_path: ", path) - # Keep LayerNorm/GroupNorm weights and biases in full precision - return x.astype(jnp.float32) - else: - # Cast everything else to dtype_to_cast - return x.astype(dtype_to_cast) - - -def basic_clean(text): - if is_ftfy_available(): - import ftfy - - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -def prompt_clean(text): - text = whitespace_clean(basic_clean(text)) - return text - - -def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState: - vs.sharding_rules = logical_axis_rules - return vs - - -# For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. -def create_sharded_logical_transformer( - devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder: str = "" -): - - def create_model(rngs: nnx.Rngs, wan_config: dict): - wan_transformer = WanModel(**wan_config, rngs=rngs) - return wan_transformer - - # 1. Load config. - if restored_checkpoint: - wan_config = restored_checkpoint["wan_config"] - else: - wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder) - wan_config["mesh"] = mesh - wan_config["dtype"] = config.activations_dtype - wan_config["weights_dtype"] = config.weights_dtype - wan_config["attention"] = config.attention - wan_config["precision"] = get_precision(config) - wan_config["flash_block_sizes"] = get_flash_block_sizes(config) - wan_config["remat_policy"] = config.remat_policy - wan_config["names_which_can_be_saved"] = config.names_which_can_be_saved - wan_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded - wan_config["flash_min_seq_length"] = config.flash_min_seq_length - wan_config["dropout"] = config.dropout - wan_config["scan_layers"] = config.scan_layers - - # 2. eval_shape - will not use flops or create weights on device - # thus not using HBM memory. - p_model_factory = partial(create_model, wan_config=wan_config) - wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs) - graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) - - # 3. retrieve the state shardings, mapping logical names to mesh axis names. - logical_state_spec = nnx.get_partition_spec(state) - logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) - logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) - params = state.to_pure_dict() - state = dict(nnx.to_flat_state(state)) - - # 4. Load pretrained weights and move them to device using the state shardings from (3) above. - # This helps with loading sharded weights directly into the accelerators without fist copying them - # all to one device and then distributing them, thus using low HBM memory. - if restored_checkpoint: - if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer - params = restored_checkpoint["wan_state"]["params"] - else: # if not checkpointed with optimizer - params = restored_checkpoint["wan_state"] - else: - params = load_wan_transformer( - config.wan_transformer_pretrained_model_name_or_path, - params, - "cpu", - num_layers=wan_config["num_layers"], - scan_layers=config.scan_layers, - subfolder=subfolder, - ) - - params = jax.tree_util.tree_map_with_path( - lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params - ) - for path, val in flax.traverse_util.flatten_dict(params).items(): - if restored_checkpoint: - path = path[:-1] - sharding = logical_state_sharding[path].value - state[path].value = device_put_replicated(val, sharding) - state = nnx.from_flat_state(state) - - wan_transformer = nnx.merge(graphdef, state, rest_of_state) - return wan_transformer - - -@nnx.jit(static_argnums=(1,), donate_argnums=(0,)) -def create_sharded_logical_model(model, logical_axis_rules): - graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...) - p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=logical_axis_rules) - state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) - pspecs = nnx.get_partition_spec(state) - sharded_state = jax.lax.with_sharding_constraint(state, pspecs) - model = nnx.merge(graphdef, sharded_state, rest_of_state) - return model - - -class WanPipeline: - r""" - Pipeline for text-to-video generation using Wan. - - tokenizer ([`T5Tokenizer`]): - Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), - specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - text_encoder ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - transformer ([`WanModel`]): - Conditional Transformer to denoise the input latents. - scheduler ([`FlaxUniPCMultistepScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKLWan`]): - Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. - """ - - def __init__( - self, - tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, - low_noise_transformer: WanModel, - high_noise_transformer: WanModel, - vae: AutoencoderKLWan, - vae_cache: AutoencoderKLWanCache, - scheduler: FlaxUniPCMultistepScheduler, - scheduler_state: UniPCMultistepSchedulerState, - devices_array: np.array, - mesh: Mesh, - config: HyperParameters, - ): - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.low_noise_transformer = low_noise_transformer - self.high_noise_transformer = high_noise_transformer - self.vae = vae - self.vae_cache = vae_cache - self.scheduler = scheduler - self.scheduler_state = scheduler_state - self.devices_array = devices_array - self.mesh = mesh - self.config = config - - self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 - self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - - self.p_run_inference = None - - @classmethod - def load_text_encoder(cls, config: HyperParameters): - text_encoder = UMT5EncoderModel.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="text_encoder", - ) - return text_encoder - - @classmethod - def load_tokenizer(cls, config: HyperParameters): - tokenizer = AutoTokenizer.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="tokenizer", - ) - return tokenizer - - @classmethod - def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): - - def create_model(rngs: nnx.Rngs, config: HyperParameters): - wan_vae = AutoencoderKLWan.from_config( - config.pretrained_model_name_or_path, - subfolder="vae", - rngs=rngs, - mesh=mesh, - dtype=jnp.float32, - weights_dtype=jnp.float32, - ) - return wan_vae - - # 1. eval shape - p_model_factory = partial(create_model, config=config) - wan_vae = nnx.eval_shape(p_model_factory, rngs=rngs) - graphdef, state = nnx.split(wan_vae, nnx.Param) - - # 2. retrieve the state shardings, mapping logical names to mesh axis names. - logical_state_spec = nnx.get_partition_spec(state) - logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) - logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) - params = state.to_pure_dict() - state = dict(nnx.to_flat_state(state)) - - # 4. Load pretrained weights and move them to device using the state shardings from (3) above. - # This helps with loading sharded weights directly into the accelerators without fist copying them - # all to one device and then distributing them, thus using low HBM memory. - params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") - params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) - for path, val in flax.traverse_util.flatten_dict(params).items(): - sharding = logical_state_sharding[path].value - if config.replicate_vae: - sharding = NamedSharding(mesh, P()) - state[path].value = device_put_replicated(val, sharding) - state = nnx.from_flat_state(state) - - wan_vae = nnx.merge(graphdef, state) - vae_cache = AutoencoderKLWanCache(wan_vae) - return wan_vae, vae_cache - - @classmethod - def get_basic_config(cls, dtype, config: HyperParameters): - rules = [ - qwix.QtRule( - module_path=config.qwix_module_path, - weight_qtype=dtype, - act_qtype=dtype, - op_names=("dot_general", "einsum", "conv_general_dilated"), - ) - ] - return rules - - @classmethod - def get_fp8_config(cls, config: HyperParameters): - """ - fp8 config rules with per-tensor calibration. - FLAX API (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api): - The autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, which is the recommended practice. - """ - rules = [ - qwix.QtRule( - module_path=config.qwix_module_path, - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e5m2, - disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=config.quantization_calibration_method, - act_calibration_method=config.quantization_calibration_method, - bwd_calibration_method=config.quantization_calibration_method, - op_names=("dot_general", "einsum"), - ), - qwix.QtRule( - module_path=config.qwix_module_path, - weight_qtype=jnp.float8_e4m3fn, # conv_general_dilated requires the same dtypes - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e4m3fn, - disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=config.quantization_calibration_method, - act_calibration_method=config.quantization_calibration_method, - bwd_calibration_method=config.quantization_calibration_method, - op_names=("conv_general_dilated"), - ), - ] - return rules - - @classmethod - def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]: - """Get quantization rules based on the config.""" - if not getattr(config, "use_qwix_quantization", False): - return None - - match config.quantization: - case "int8": - return qwix.QtProvider(cls.get_basic_config(jnp.int8, config)) - case "fp8": - return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn, config)) - case "fp8_full": - return qwix.QtProvider(cls.get_fp8_config(config)) - return None - - @classmethod - def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline: "WanPipeline", mesh: Mesh): - """Quantizes the transformer model.""" - q_rules = cls.get_qt_provider(config) - if not q_rules: - return model - max_logging.log("Quantizing transformer with Qwix.") - - batch_size = jnp.ceil(config.per_device_batch_size * jax.local_device_count()).astype(jnp.int32) - latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size) - model_inputs = (latents, timesteps, prompt_embeds) - with mesh: - quantized_model = qwix.quantize_model(model, q_rules, *model_inputs) - max_logging.log("Qwix Quantization complete.") - return quantized_model - - @classmethod - def load_transformer( - cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder="transformer"): - with mesh: - wan_transformer = create_sharded_logical_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder=subfolder - ) - return wan_transformer - - @classmethod - def load_scheduler(cls, config): - scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="scheduler", - flow_shift=config.flow_shift, # 5.0 for 720p, 3.0 for 480p - ) - return scheduler, scheduler_state - - @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - low_noise_transformer = None - high_noise_transformer = None - tokenizer = None - scheduler = None - scheduler_state = None - text_encoder = None - if not vae_only: - if load_transformer: - with mesh: - low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer") - high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer_2") - - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - return WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - low_noise_transformer=low_noise_transformer, - high_noise_transformer=high_noise_transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config, - ) - - @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - low_noise_transformer = None - high_noise_transformer = None - tokenizer = None - scheduler = None - scheduler_state = None - text_encoder = None - if not vae_only: - if load_transformer: - with mesh: - low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") - high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer_2") - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - pipeline = WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - low_noise_transformer=low_noise_transformer, - high_noise_transformer=high_noise_transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config, - ) - - pipeline.low_noise_transformer = cls.quantize_transformer(config, pipeline.low_noise_transformer, pipeline, mesh) - pipeline.high_noise_transformer = cls.quantize_transformer(config, pipeline.high_noise_transformer, pipeline, mesh) - return pipeline - - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - ): - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(u) for u in prompt] - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask - seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - return prompt_embeds - - def encode_prompt( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - ): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - if prompt_embeds is None: - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - ) - prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32) - - if negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_embeds = self._get_t5_prompt_embeds( - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - ) - negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32) - - return prompt_embeds, negative_prompt_embeds - - def prepare_latents( - self, - batch_size: int, - vae_scale_factor_temporal: int, - vae_scale_factor_spatial: int, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_channels_latents: int = 16, - ): - rng = jax.random.key(self.config.seed) - num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 - shape = ( - batch_size, - num_channels_latents, - num_latent_frames, - int(height) // vae_scale_factor_spatial, - int(width) // vae_scale_factor_spatial, - ) - latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) - - return latents - - def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale_low: float = 3.0, - guidance_scale_high: float = 4.0, - boundary: int = 875, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: jax.Array = None, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, - ): - if not vae_only: - if num_frames % self.vae_scale_factor_temporal != 1: - max_logging.log( - f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." - ) - num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 - num_frames = max(num_frames, 1) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - prompt = [prompt] - - batch_size = len(prompt) - - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - max_sequence_length=max_sequence_length, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - num_channel_latents = self.low_noise_transformer.config.in_channels - if latents is None: - latents = self.prepare_latents( - batch_size=batch_size, - vae_scale_factor_temporal=self.vae_scale_factor_temporal, - vae_scale_factor_spatial=self.vae_scale_factor_spatial, - height=height, - width=width, - num_frames=num_frames, - num_channels_latents=num_channel_latents, - ) - - data_sharding = NamedSharding(self.mesh, P()) - # Using global_batch_size_to_train_on so not to create more config variables - if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: - data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) - - latents = jax.device_put(latents, data_sharding) - prompt_embeds = jax.device_put(prompt_embeds, data_sharding) - negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) - - scheduler_state = self.scheduler.set_timesteps( - self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape - ) - - low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) - high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) - - p_run_inference = partial( - run_inference, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - scheduler_state=scheduler_state, - ) - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - latents = p_run_inference( - low_noise_graphdef=low_noise_graphdef, - low_noise_state=low_noise_state, - low_noise_rest=low_noise_rest, - high_noise_graphdef=high_noise_graphdef, - high_noise_state=high_noise_state, - high_noise_rest=high_noise_rest, - latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) - latents = latents / latents_std + latents_mean - latents = latents.astype(jnp.float32) - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - video = self.vae.decode(latents, self.vae_cache)[0] - - video = jnp.transpose(video, (0, 4, 1, 2, 3)) - video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) - video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) - video = self.video_processor.postprocess_video(video, output_type="np") - return video - - -@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) -def transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, - prompt_embeds, - do_classifier_free_guidance, - guidance_scale, -): - wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) - if do_classifier_free_guidance: - bsz = latents.shape[0] // 2 - noise_uncond = noise_pred[bsz:] - noise_pred = noise_pred[:bsz] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - latents = latents[:bsz] - - return noise_pred, latents - -def run_inference( - low_noise_graphdef, - low_noise_state, - low_noise_rest, - high_noise_graphdef, - high_noise_state, - high_noise_rest, - latents: jnp.array, - prompt_embeds: jnp.array, - negative_prompt_embeds: jnp.array, - guidance_scale_low: float, - guidance_scale_high: float, - boundary: int, - num_inference_steps: int, - scheduler: FlaxUniPCMultistepScheduler, - scheduler_state, -): - do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 - if do_classifier_free_guidance: - prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - - def low_noise_branch(operands): - latents, timestep, prompt_embeds = operands - return transformer_forward_pass( - low_noise_graphdef, low_noise_state, low_noise_rest, - latents, timestep, prompt_embeds, - do_classifier_free_guidance, guidance_scale_low - ) - - def high_noise_branch(operands): - latents, timestep, prompt_embeds = operands - return transformer_forward_pass( - high_noise_graphdef, high_noise_state, high_noise_rest, - latents, timestep, prompt_embeds, - do_classifier_free_guidance, guidance_scale_high - ) - - for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - if do_classifier_free_guidance: - latents = jnp.concatenate([latents] * 2) - timestep = jnp.broadcast_to(t, latents.shape[0]) - use_high_noise = jnp.greater_equal(t, boundary) - - noise_pred, latents = jax.lax.cond( - use_high_noise, - high_noise_branch, - low_noise_branch, - (latents, timestep, prompt_embeds) - ) - - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents diff --git a/src/maxdiffusion/tests/wan_checkpointer2_2_test.py b/src/maxdiffusion/tests/wan_checkpointer2_2_test.py deleted file mode 100644 index 8e1fa0be4..000000000 --- a/src/maxdiffusion/tests/wan_checkpointer2_2_test.py +++ /dev/null @@ -1,113 +0,0 @@ -""" - Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import unittest -from unittest.mock import patch, MagicMock - -from maxdiffusion.checkpointing.wan_checkpointer2_2 import WanCheckpointer, WAN_CHECKPOINT - - -class WanCheckpointerTest(unittest.TestCase): - - def setUp(self): - self.config = MagicMock() - self.config.checkpoint_dir = "/tmp/wan_checkpoint_test" - self.config.dataset_type = "test_dataset" - - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline") - def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): - mock_manager = MagicMock() - mock_manager.latest_step.return_value = None - mock_create_manager.return_value = mock_manager - - mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance - - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) - pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) - - mock_manager.latest_step.assert_called_once() - mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config) - self.assertEqual(pipeline, mock_pipeline_instance) - self.assertIsNone(opt_state) - self.assertIsNone(step) - - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline") - def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): - mock_manager = MagicMock() - mock_manager.latest_step.return_value = 1 - metadata_mock = MagicMock() - metadata_mock.low_noise_transformer_state = {} - metadata_mock.high_noise_transformer_state = {} - mock_manager.item_metadata.return_value = metadata_mock - - restored_mock = MagicMock() - restored_mock.low_noise_transformer_state = {"params": {}} - restored_mock.high_noise_transformer_state = {"params": {}} - restored_mock.wan_config = {} - restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] - - mock_manager.restore.return_value = restored_mock - - mock_create_manager.return_value = mock_manager - - mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) - pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) - - mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) - mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) - self.assertEqual(pipeline, mock_pipeline_instance) - self.assertIsNone(opt_state) - self.assertEqual(step, 1) - - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline") - def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager): - mock_manager = MagicMock() - mock_manager.latest_step.return_value = 1 - metadata_mock = MagicMock() - metadata_mock.low_noise_transformer_state = {} - metadata_mock.high_noise_transformer_state = {} - mock_manager.item_metadata.return_value = metadata_mock - - restored_mock = MagicMock() - restored_mock.low_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.001}} - restored_mock.high_noise_transformer_state = {"params": {}} - restored_mock.wan_config = {} - restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] - - mock_manager.restore.return_value = restored_mock - - mock_create_manager.return_value = mock_manager - - mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) - pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) - - mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) - mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) - self.assertEqual(pipeline, mock_pipeline_instance) - self.assertIsNotNone(opt_state) - self.assertEqual(opt_state["learning_rate"], 0.001) - self.assertEqual(step, 1) - - -if __name__ == "__main__": - unittest.main() From b7aad0aaa8d8bd3f7e1964f657d983f0a9aa2add Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 12 Nov 2025 08:51:07 +0530 Subject: [PATCH 09/28] Updated README and generate_wan.py --- README.md | 18 +----------------- src/maxdiffusion/generate_wan.py | 3 ++- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 6a4d2048d..2deb8ba96 100644 --- a/README.md +++ b/README.md @@ -482,23 +482,7 @@ To generate images, run the following command: ```bash HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 - ``` - ## Wan2.2 - - Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). - - ```bash - HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 - ``` - ## Wan2.2 - - Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). - - ```bash - HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 ``` ## Wan2.2 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 53a38ac45..fc3a3626a 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -119,7 +119,8 @@ def run(config, pipeline=None, filename_prefix=""): model_key = config.model_name checkpoint_loader = WanCheckpointer(model_key=model_key, config=config) pipeline, _, _ = checkpoint_loader.load_checkpoint() - pipeline = WanPipeline.from_pretrained(model_key=model_key, config=config) + if pipeline is None: + pipeline = WanPipeline.from_pretrained(model_key=model_key, config=config) s0 = time.perf_counter() # Using global_batch_size_to_train_on so not to create more config variables From 16d657ab18984b0bc73af39e3b341146d017c01b Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Tue, 18 Nov 2025 18:28:57 +0530 Subject: [PATCH 10/28] Added tensorboard logging for inference metrics --- requirements_with_jax_ai_image.txt | 1 + src/maxdiffusion/generate_wan.py | 37 ++++++++++++++++++++++++++---- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index 2a2287d60..c279edb88 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -30,6 +30,7 @@ orbax-checkpoint tokenizers==0.21.0 huggingface_hub>=0.30.2 transformers==4.48.1 +tokamax einops==0.8.0 sentencepiece aqtp diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index fc3a3626a..d75fd7eea 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -115,8 +115,10 @@ def inference_generate_video(config, pipeline, filename_prefix=""): def run(config, pipeline=None, filename_prefix=""): - print("seed: ", config.seed) model_key = config.model_name + writer = max_utils.initialize_summary_writer(config) + if jax.process_index() == 0 and writer: + max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") checkpoint_loader = WanCheckpointer(model_key=model_key, config=config) pipeline, _, _ = checkpoint_loader.load_checkpoint() if pipeline is None: @@ -132,7 +134,19 @@ def run(config, pipeline=None, filename_prefix=""): ) videos = call_pipeline(config, pipeline, prompt, negative_prompt) - print("compile time: ", (time.perf_counter() - s0)) + max_logging.log("===================== Model details =======================") + max_logging.log(f"model name: {config.model_name}") + max_logging.log(f"model path: {config.pretrained_model_name_or_path}") + max_logging.log("model type: t2v") + max_logging.log(f"hardware: {jax.devices()[0].platform}") + max_logging.log(f"number of devices: {jax.device_count()}") + max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}") + max_logging.log("============================================================") + + compile_time = time.perf_counter() - s0 + max_logging.log(f"compile_time: {compile_time}") + if writer and jax.process_index() == 0: + writer.add_scalar("inference/compile_time", compile_time, global_step=0) saved_video_path = [] for i in range(len(videos)): video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4" @@ -143,14 +157,27 @@ def run(config, pipeline=None, filename_prefix=""): s0 = time.perf_counter() videos = call_pipeline(config, pipeline, prompt, negative_prompt) - print("generation time: ", (time.perf_counter() - s0)) - + generation_time = time.perf_counter() - s0 + max_logging.log(f"generation_time: {generation_time}") + if writer and jax.process_index() == 0: + writer.add_scalar("inference/generation_time", generation_time, global_step=0) + num_devices = jax.device_count() + num_videos = num_devices * config.per_device_batch_size + if num_videos > 0: + generation_time_per_video = generation_time / num_videos + writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0) + max_logging.log(f"generation time per video: {generation_time_per_video}") + else: + max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.") s0 = time.perf_counter() if config.enable_profiler: max_utils.activate_profiler(config) videos = call_pipeline(config, pipeline, prompt, negative_prompt) max_utils.deactivate_profiler(config) - print("generation time: ", (time.perf_counter() - s0)) + generation_time_with_profiler = time.perf_counter() - s0 + max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}") + if writer and jax.process_index() == 0: + writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0) return saved_video_path From cc78cac9205efff778207e12f002db3bc3b3a87c Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Thu, 20 Nov 2025 07:49:55 +0530 Subject: [PATCH 11/28] Fixed duplicate pipeline loading --- src/maxdiffusion/generate_wan.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index d75fd7eea..e23395a01 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -119,10 +119,10 @@ def run(config, pipeline=None, filename_prefix=""): writer = max_utils.initialize_summary_writer(config) if jax.process_index() == 0 and writer: max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") - checkpoint_loader = WanCheckpointer(model_key=model_key, config=config) - pipeline, _, _ = checkpoint_loader.load_checkpoint() + if pipeline is None: - pipeline = WanPipeline.from_pretrained(model_key=model_key, config=config) + checkpoint_loader = WanCheckpointer(model_key=model_key, config=config) + pipeline, _, _ = checkpoint_loader.load_checkpoint() s0 = time.perf_counter() # Using global_batch_size_to_train_on so not to create more config variables From c46fd87e5b23b17701276247bfddfe68be997cea Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Thu, 20 Nov 2025 08:57:56 +0530 Subject: [PATCH 12/28] Merge conflicts --- src/maxdiffusion/generate_wan.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index e23395a01..4f506fdfa 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -63,6 +63,15 @@ def delete_file(file_path: str): jax.config.update("jax_use_shardy_partitioner", True) +jax.config.update("jax_default_prng_impl", "unsafe_rbg") + # TF allocates extraneous GPU memory when using TFDS data + # this leads to CUDA OOMs. WAR for now is to hide GPUs from TF + # tf.config.set_visible_devices([], "GPU") +if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): + max_logging.log("Enabling unsafe RNG bit generator for TPU SPMD.") + os.environ["LIBTPU_INIT_ARGS"] = ( + os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + ) def call_pipeline(config, pipeline, prompt, negative_prompt): model_key = config.model_name From e55ccd233e854d0370b09a5406e486abf32f58db Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Thu, 20 Nov 2025 09:21:39 +0530 Subject: [PATCH 13/28] ruff errors --- src/maxdiffusion/generate_wan.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 35b553f5a..e422afa1e 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -16,7 +16,6 @@ import jax import time import os -from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer from maxdiffusion import pyconfig, max_logging, max_utils from absl import app @@ -128,7 +127,7 @@ def run(config, pipeline=None, filename_prefix=""): writer = max_utils.initialize_summary_writer(config) if jax.process_index() == 0 and writer: max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") - + if pipeline is None: checkpoint_loader = WanCheckpointer(model_key=model_key, config=config) pipeline, _, _ = checkpoint_loader.load_checkpoint() From ebc7eec13e585399ef09da66109ab37a4f322705 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Thu, 20 Nov 2025 11:21:30 +0530 Subject: [PATCH 14/28] Changes to Wan trainer for compatibility with checkpointer --- src/maxdiffusion/trainers/wan_trainer.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index fb01a4f44..250daa3f3 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -85,12 +85,15 @@ def print_ssim(pretrained_video_path, posttrained_video_path): max_logging.log(f"SSIM score after training is {ssim_compare}") -class WanTrainer(WanCheckpointer): +class WanTrainer: def __init__(self, config): - WanCheckpointer.__init__(self, config, WAN_CHECKPOINT) + # WanCheckpointer.__init__(self, config, WAN_CHECKPOINT) if config.train_text_encoder: raise ValueError("this script currently doesn't support training text_encoders") + self.config = config + model_key = config.model_name + self.checkpointer = WanCheckpointer(model_key, config, WAN_CHECKPOINT) def post_training_steps(self, pipeline, params, train_states, msg=""): pass @@ -210,7 +213,7 @@ def prepare_sample_eval(features): def start_training(self): - pipeline, opt_state, step = self.load_checkpoint() + pipeline, opt_state, step = self.checkpointer.load_checkpoint() restore_args = {} if opt_state and step: restore_args = {"opt_state": opt_state, "step": step} @@ -231,7 +234,7 @@ def start_training(self): scheduler, scheduler_state = self.create_scheduler() pipeline.scheduler = scheduler pipeline.scheduler_state = scheduler_state - optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5) + optimizer, learning_rate_scheduler = self.checkpointer._create_optimizer(pipeline.transformer, self.config, 1e-5) # Returns pipeline with trained transformer state pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args) @@ -392,9 +395,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0: max_logging.log(f"Saving checkpoint for step {step}") if self.config.save_optimizer: - self.save_checkpoint(step, pipeline, state) + self.checkpointer.save_checkpoint(step, pipeline, state) else: - self.save_checkpoint(step, pipeline, state.params) + self.checkpointer.save_checkpoint(step, pipeline, state.params) _metrics_queue.put(None) writer_thread.join() @@ -402,8 +405,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data writer.flush() if self.config.save_final_checkpoint: max_logging.log(f"Saving final checkpoint for step {step}") - self.save_checkpoint(self.config.max_train_steps - 1, pipeline, state.params) - self.checkpoint_manager.wait_until_finished() + self.checkpointer.save_checkpoint(self.config.max_train_steps - 1, pipeline, state.params) + self.checkpointer.checkpoint_manager.wait_until_finished() # load new state for trained tranformer pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state) return pipeline From d6cdb1ef704b6938f27d6e47d20014a935198a48 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Thu, 20 Nov 2025 12:34:03 +0530 Subject: [PATCH 15/28] flash block size changed for testing --- src/maxdiffusion/configs/base_wan_14b.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 78dca3be4..5bec7ad07 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -74,13 +74,13 @@ attention_sharding_uniform: True dropout: 0.1 flash_block_sizes: { - "block_q" : 3024, + "block_q" : 2048, "block_kv_compute" : 1024, "block_kv" : 2048, - "block_q_dkv" : 3024, + "block_q_dkv" : 2048, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, - "block_q_dq" : 3024, + "block_q_dq" : 2048, "block_kv_dq" : 2048 } # Use on v6e From b3edab669a598cc7c36591e7589d98fc29655766 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Thu, 20 Nov 2025 17:25:02 +0530 Subject: [PATCH 16/28] Revert "flash block size changed for testing" This reverts commit d6cdb1ef704b6938f27d6e47d20014a935198a48. --- src/maxdiffusion/configs/base_wan_14b.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 5bec7ad07..78dca3be4 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -74,13 +74,13 @@ attention_sharding_uniform: True dropout: 0.1 flash_block_sizes: { - "block_q" : 2048, + "block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, - "block_q_dkv" : 2048, + "block_q_dkv" : 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, - "block_q_dq" : 2048, + "block_q_dq" : 3024, "block_kv_dq" : 2048 } # Use on v6e From 2c494d1d585399c5255aadcef01c42b909cf8df3 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Thu, 20 Nov 2025 20:18:15 +0530 Subject: [PATCH 17/28] Raise error for unsupported model training --- src/maxdiffusion/trainers/wan_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 250daa3f3..68979b990 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -93,6 +93,8 @@ def __init__(self, config): raise ValueError("this script currently doesn't support training text_encoders") self.config = config model_key = config.model_name + if model_key != 'wan2.1': + raise ValueError(f"Unsupported model_name: '{model_key}'. This trainer only supports 'wan2.1'.") self.checkpointer = WanCheckpointer(model_key, config, WAN_CHECKPOINT) def post_training_steps(self, pipeline, params, train_states, msg=""): From 5e642f8748f901162c7f3f84ebacfbda36171045 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sat, 22 Nov 2025 13:34:27 +0530 Subject: [PATCH 18/28] Explicitly instantiate WanPipeline and WanCheckpointer subclasses --- .../checkpointing/wan_checkpointer.py | 18 +------------- src/maxdiffusion/generate_wan.py | 9 +++++-- .../pipelines/wan/wan_pipeline.py | 24 ------------------- .../tests/wan_checkpointer_test.py | 18 +++++++------- src/maxdiffusion/trainers/wan_trainer.py | 5 ++-- 5 files changed, 19 insertions(+), 55 deletions(-) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 74710f4fd..abccea693 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -31,21 +31,8 @@ class WanCheckpointer(ABC): - _SUBCLASS_MAP: dict[str, Type['WanCheckpointer']] = {} - - def __new__(cls, model_key: str, config, checkpoint_type: str = WAN_CHECKPOINT): - if cls is WanCheckpointer: - subclass = cls._SUBCLASS_MAP.get(model_key) - if subclass is None: - raise ValueError( - f"Unknown model_key: '{model_key}'. " - f"Supported keys are: {list(cls._SUBCLASS_MAP.keys())}" - ) - return super().__new__(subclass) - else: - return super().__new__(cls) - def __init__(self, model_key, config, checkpoint_type: str = WAN_CHECKPOINT): + def __init__(self, config, checkpoint_type: str = WAN_CHECKPOINT): self.config = config self.checkpoint_type = checkpoint_type self.opt_state = None @@ -244,9 +231,6 @@ def config_to_json(model_or_config): self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) max_logging.log(f"Checkpoint for step {train_step} saved.") -WanCheckpointer._SUBCLASS_MAP["wan2.1"] = WanCheckpointer2_1 -WanCheckpointer._SUBCLASS_MAP["wan2.2"] = WanCheckpointer2_2 - def save_checkpoint_orig(self, train_step, pipeline, train_states: dict): """Saves the training state and model configurations.""" diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index e422afa1e..dabbf4b95 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -16,7 +16,7 @@ import jax import time import os -from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer2_1, WanCheckpointer2_2 from maxdiffusion import pyconfig, max_logging, max_utils from absl import app from maxdiffusion.utils import export_to_video @@ -129,7 +129,12 @@ def run(config, pipeline=None, filename_prefix=""): max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") if pipeline is None: - checkpoint_loader = WanCheckpointer(model_key=model_key, config=config) + if model_key == "wan2.1": + checkpoint_loader = WanCheckpointer2_1(config=config) + elif model_key == "wan2.2": + checkpoint_loader = WanCheckpointer2_2(config=config) + else: + raise ValueError(f"Unsupported model_name for checkpointer: {model_key}") pipeline, _, _ = checkpoint_loader.load_checkpoint() s0 = time.perf_counter() diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 9ad1768a6..55b019155 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -189,7 +189,6 @@ class WanPipeline: vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. """ - _SUBCLASS_MAP: dict[str, Type['WanPipeline']] = {} def __init__( self, tokenizer: AutoTokenizer, @@ -500,26 +499,6 @@ def _create_common_components(cls, config, vae_only=False): components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) return components - @classmethod - def _get_subclass(cls, model_key: str) -> Type['WanPipeline']: - subclass = cls._SUBCLASS_MAP.get(model_key) - if subclass is None: - raise ValueError( - f"Unknown model_key for WanPipeline: '{model_key}'. " - f"Supported keys are: {list(cls._SUBCLASS_MAP.keys())}" - ) - return subclass - - @classmethod - def from_checkpoint(cls, model_key: str, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): - subclass = cls._get_subclass(model_key) - return subclass.from_checkpoint(config, restored_checkpoint=restored_checkpoint, vae_only=vae_only, load_transformer=load_transformer) - - @classmethod - def from_pretrained(cls, model_key: str, config: HyperParameters, vae_only=False, load_transformer=True): - subclass = cls._get_subclass(model_key) - return subclass.from_pretrained(config, vae_only=vae_only, load_transformer=load_transformer) - @abstractmethod def _get_num_channel_latents(self) -> int: """Returns the number of input channels for the transformer.""" @@ -929,6 +908,3 @@ def high_noise_branch(operands): latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents - -WanPipeline._SUBCLASS_MAP["wan2.1"] = WanPipeline2_1 -WanPipeline._SUBCLASS_MAP["wan2.2"] = WanPipeline2_2 diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py index 554c8824c..79f050c07 100644 --- a/src/maxdiffusion/tests/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -34,7 +34,7 @@ def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer2_1(model_key=self.config.model_key, config=self.config) + checkpointer = WanCheckpointer2_1(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) mock_manager.latest_step.assert_called_once() @@ -64,7 +64,7 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer2_1(model_key=self.config.model_key, config=self.config) + checkpointer = WanCheckpointer2_1(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) @@ -94,7 +94,7 @@ def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_man mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer2_1(model_key=self.config.model_key, config=self.config) + checkpointer = WanCheckpointer2_1(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) @@ -124,7 +124,7 @@ def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) mock_manager.latest_step.assert_called_once() @@ -157,7 +157,7 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) @@ -190,7 +190,7 @@ def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline, mo mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) @@ -224,7 +224,7 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, m mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) @@ -263,7 +263,7 @@ def test_load_checkpoint_with_explicit_none_step(self, mock_wan_pipeline, mock_c mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer2_1(model_key=self.config.model_key, config=self.config) + checkpointer = WanCheckpointer2_1(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) mock_manager.latest_step.assert_called_once() @@ -291,7 +291,7 @@ def test_load_checkpoint_both_optimizers_present(self, mock_wan_pipeline, mock_c mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) # Should prioritize low_noise_transformer's optimizer state diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 68979b990..af3e9eb29 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -29,7 +29,7 @@ from maxdiffusion.schedulers import FlaxFlowMatchScheduler from flax.linen import partitioning as nn_partitioning from maxdiffusion import max_utils, max_logging, train_utils -from maxdiffusion.checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT) +from maxdiffusion.checkpointing.wan_checkpointer import (WanCheckpointer2_1, WAN_CHECKPOINT) from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) from maxdiffusion.generate_wan import run as generate_wan from maxdiffusion.generate_wan import inference_generate_video @@ -88,14 +88,13 @@ def print_ssim(pretrained_video_path, posttrained_video_path): class WanTrainer: def __init__(self, config): - # WanCheckpointer.__init__(self, config, WAN_CHECKPOINT) if config.train_text_encoder: raise ValueError("this script currently doesn't support training text_encoders") self.config = config model_key = config.model_name if model_key != 'wan2.1': raise ValueError(f"Unsupported model_name: '{model_key}'. This trainer only supports 'wan2.1'.") - self.checkpointer = WanCheckpointer(model_key, config, WAN_CHECKPOINT) + self.checkpointer = WanCheckpointer2_1(config=config) def post_training_steps(self, pipeline, params, train_states, msg=""): pass From 531e64de55dc3953752f6ac12da25e86929855fa Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sat, 22 Nov 2025 13:38:33 +0530 Subject: [PATCH 19/28] ruff errors --- src/maxdiffusion/checkpointing/wan_checkpointer.py | 2 +- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 2 +- src/maxdiffusion/trainers/wan_trainer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index abccea693..12151bff7 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -19,7 +19,7 @@ import jax import numpy as np -from typing import Optional, Tuple, Type +from typing import Optional, Tuple from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) from ..pipelines.wan.wan_pipeline import WanPipeline2_1, WanPipeline2_2 from .. import max_logging, max_utils diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 55b019155..23ff46b81 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import abstractmethod -from typing import List, Union, Optional, Type +from typing import List, Union, Optional from functools import partial import numpy as np import jax diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index af3e9eb29..deecfd433 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -29,7 +29,7 @@ from maxdiffusion.schedulers import FlaxFlowMatchScheduler from flax.linen import partitioning as nn_partitioning from maxdiffusion import max_utils, max_logging, train_utils -from maxdiffusion.checkpointing.wan_checkpointer import (WanCheckpointer2_1, WAN_CHECKPOINT) +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer2_1 from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) from maxdiffusion.generate_wan import run as generate_wan from maxdiffusion.generate_wan import inference_generate_video From 598b0bc09ed7d5bee0b14d2f844c19c0536e794e Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Mon, 24 Nov 2025 16:06:02 +0530 Subject: [PATCH 20/28] Added commit_id to tensorboard logging --- src/maxdiffusion/generate_wan.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index dabbf4b95..f10e1b9d1 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -16,6 +16,7 @@ import jax import time import os +import subprocess from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer2_1, WanCheckpointer2_2 from maxdiffusion import pyconfig, max_logging, max_utils from absl import app @@ -60,6 +61,17 @@ def delete_file(file_path: str): else: max_logging.log(f"The file '{file_path}' does not exist.") +def get_git_commit_hash(): + """Tries to get the current Git commit hash.""" + try: + commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip().decode('utf-8') + return commit_hash + except subprocess.CalledProcessError: + max_logging.log("Warning: 'git rev-parse HEAD' failed. Not running in a git repo?") + return None + except FileNotFoundError: + max_logging.log("Warning: 'git' command not found.") + return None jax.config.update("jax_use_shardy_partitioner", True) jax.config.update("jax_default_prng_impl", "unsafe_rbg") @@ -128,6 +140,19 @@ def run(config, pipeline=None, filename_prefix=""): if jax.process_index() == 0 and writer: max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") + commit_hash = get_git_commit_hash() + if commit_hash: + writer.add_text("inference/git_commit_hash", commit_hash, global_step=0) + max_logging.log(f"Git Commit Hash: {commit_hash}") + else: + # Fallback for CI environments like GitHub Actions + github_sha = os.environ.get('GITHUB_SHA') + if github_sha: + writer.add_text("inference/git_commit_hash", github_sha, global_step=0) + max_logging.log(f"Git Commit Hash (from GITHUB_SHA): {github_sha}") + else: + max_logging.log("Could not retrieve Git commit hash.") + if pipeline is None: if model_key == "wan2.1": checkpoint_loader = WanCheckpointer2_1(config=config) From 81dab2743ed916b9ad7d2a533e835cf40f412d53 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Mon, 24 Nov 2025 16:13:30 +0530 Subject: [PATCH 21/28] Commit hash logging --- src/maxdiffusion/generate_wan.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index f10e1b9d1..8d3768006 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -145,13 +145,7 @@ def run(config, pipeline=None, filename_prefix=""): writer.add_text("inference/git_commit_hash", commit_hash, global_step=0) max_logging.log(f"Git Commit Hash: {commit_hash}") else: - # Fallback for CI environments like GitHub Actions - github_sha = os.environ.get('GITHUB_SHA') - if github_sha: - writer.add_text("inference/git_commit_hash", github_sha, global_step=0) - max_logging.log(f"Git Commit Hash (from GITHUB_SHA): {github_sha}") - else: - max_logging.log("Could not retrieve Git commit hash.") + max_logging.log("Could not retrieve Git commit hash.") if pipeline is None: if model_key == "wan2.1": From 6d3b5974cd10831122892a15de287eb4f75252b2 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Tue, 25 Nov 2025 20:30:18 +0530 Subject: [PATCH 22/28] Added enable_jax_named_scopes param for wan 2.2 --- src/maxdiffusion/configs/base_wan_27b.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 2a998d0cf..217593ff9 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -271,6 +271,10 @@ enable_profiler: False skip_first_n_steps_for_profiler: 5 profiler_steps: 10 +# Enable JAX named scopes for detailed profiling and debugging +# When enabled, adds named scopes around key operations in transformer and attention layers +enable_jax_named_scopes: False + # Generation parameters prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." From 168b39ad35603189e14c4bc8a6b7d2f9e18f61b0 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 3 Dec 2025 14:03:15 +0530 Subject: [PATCH 23/28] Wanpipeline and WanCheckpointer split into files --- .../checkpointing/wan_checkpointer.py | 169 +--------- .../checkpointing/wan_checkpointer_2_1.py | 95 ++++++ .../checkpointing/wan_checkpointer_2_2.py | 114 +++++++ src/maxdiffusion/common_types.py | 5 +- src/maxdiffusion/configs/base_wan_27b.yml | 8 + src/maxdiffusion/generate_wan.py | 9 +- .../pipelines/wan/wan_pipeline.py | 317 +----------------- .../pipelines/wan/wan_pipeline_2_1.py | 162 +++++++++ .../pipelines/wan/wan_pipeline_2_2.py | 202 +++++++++++ src/maxdiffusion/pyconfig.py | 22 +- .../tests/wan_checkpointer_test.py | 4 +- src/maxdiffusion/train_wan.py | 2 +- src/maxdiffusion/trainers/wan_trainer.py | 3 - 13 files changed, 618 insertions(+), 494 deletions(-) create mode 100644 src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py create mode 100644 src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py create mode 100644 src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py create mode 100644 src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 12151bff7..b601cb349 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -15,16 +15,12 @@ """ from abc import ABC, abstractmethod -import json - -import jax -import numpy as np from typing import Optional, Tuple from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) -from ..pipelines.wan.wan_pipeline import WanPipeline2_1, WanPipeline2_2 +from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 +from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 from .. import max_logging, max_utils import orbax.checkpoint as ocp -from etils import epath WAN_CHECKPOINT = "WAN_CHECKPOINT" @@ -70,167 +66,6 @@ def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipel def save_checkpoint(self, train_step, pipeline, train_states: dict): raise NotImplementedError - -class WanCheckpointer2_1(WanCheckpointer): - - def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: - if step is None: - step = self.checkpoint_manager.latest_step() - max_logging.log(f"Latest WAN checkpoint step: {step}") - if step is None: - max_logging.log("No WAN checkpoint found.") - return None, None - max_logging.log(f"Loading WAN checkpoint from step {step}") - metadatas = self.checkpoint_manager.item_metadata(step) - transformer_metadata = metadatas.wan_state - abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata) - params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_params, - ) - ) - - max_logging.log("Restoring WAN checkpoint") - restored_checkpoint = self.checkpoint_manager.restore( - directory=epath.Path(self.config.checkpoint_dir), - step=step, - args=ocp.args.Composite( - wan_state=params_restore, - wan_config=ocp.args.JsonRestore(), - ), - ) - max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") - max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}") - max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}") - max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") - return restored_checkpoint, step - - def load_diffusers_checkpoint(self): - pipeline = WanPipeline2_1.from_pretrained(self.config) - return pipeline - - def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_1, Optional[dict], Optional[int]]: - restored_checkpoint, step = self.load_wan_configs_from_orbax(step) - opt_state = None - if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = WanPipeline2_1.from_checkpoint(self.config, restored_checkpoint) - if "opt_state" in restored_checkpoint.wan_state.keys(): - opt_state = restored_checkpoint.wan_state["opt_state"] - else: - max_logging.log("No checkpoint found, loading default pipeline.") - pipeline = self.load_diffusers_checkpoint() - - return pipeline, opt_state, step - - def save_checkpoint(self, train_step, pipeline: WanPipeline2_1, train_states: dict): - """Saves the training state and model configurations.""" - - def config_to_json(model_or_config): - return json.loads(model_or_config.to_json_string()) - - max_logging.log(f"Saving checkpoint for step {train_step}") - items = { - "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), - } - - items["wan_state"] = ocp.args.PyTreeSave(train_states) - - # Save the checkpoint - self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) - max_logging.log(f"Checkpoint for step {train_step} saved.") - - -class WanCheckpointer2_2(WanCheckpointer): - - def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: - if step is None: - step = self.checkpoint_manager.latest_step() - max_logging.log(f"Latest WAN checkpoint step: {step}") - if step is None: - max_logging.log("No WAN checkpoint found.") - return None, None - max_logging.log(f"Loading WAN checkpoint from step {step}") - metadatas = self.checkpoint_manager.item_metadata(step) - - # Handle low_noise_transformer - low_noise_transformer_metadata = metadatas.low_noise_transformer_state - abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) - low_params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_low_params, - ) - ) - - # Handle high_noise_transformer - high_noise_transformer_metadata = metadatas.high_noise_transformer_state - abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) - high_params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_high_params, - ) - ) - - max_logging.log("Restoring WAN 2.2 checkpoint") - restored_checkpoint = self.checkpoint_manager.restore( - directory=epath.Path(self.config.checkpoint_dir), - step=step, - args=ocp.args.Composite( - low_noise_transformer_state=low_params_restore, - high_noise_transformer_state=high_params_restore, - wan_config=ocp.args.JsonRestore(), - ), - ) - max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") - max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") - max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") - return restored_checkpoint, step - - def load_diffusers_checkpoint(self): - pipeline = WanPipeline2_2.from_pretrained(self.config) - return pipeline - - def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]: - restored_checkpoint, step = self.load_wan_configs_from_orbax(step) - opt_state = None - if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint) - # Check for optimizer state in either transformer - if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): - opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] - elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): - opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] - else: - max_logging.log("No checkpoint found, loading default pipeline.") - pipeline = self.load_diffusers_checkpoint() - - return pipeline, opt_state, step - - def save_checkpoint(self, train_step, pipeline: WanPipeline2_2, train_states: dict): - """Saves the training state and model configurations.""" - - def config_to_json(model_or_config): - return json.loads(model_or_config.to_json_string()) - - max_logging.log(f"Saving checkpoint for step {train_step}") - items = { - "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), - } - - items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"]) - items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"]) - - # Save the checkpoint - self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) - max_logging.log(f"Checkpoint for step {train_step} saved.") - def save_checkpoint_orig(self, train_step, pipeline, train_states: dict): """Saves the training state and model configurations.""" diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py new file mode 100644 index 000000000..7d43582b6 --- /dev/null +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py @@ -0,0 +1,95 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import json +import jax +import numpy as np +from typing import Optional, Tuple +from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 +from .. import max_logging +import orbax.checkpoint as ocp +from etils import epath +from checkpointing.wan_checkpointer import WanCheckpointer + +class WanCheckpointer2_1(WanCheckpointer): + + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: + if step is None: + step = self.checkpoint_manager.latest_step() + max_logging.log(f"Latest WAN checkpoint step: {step}") + if step is None: + max_logging.log("No WAN checkpoint found.") + return None, None + max_logging.log(f"Loading WAN checkpoint from step {step}") + metadatas = self.checkpoint_manager.item_metadata(step) + transformer_metadata = metadatas.wan_state + abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata) + params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_params, + ) + ) + + max_logging.log("Restoring WAN checkpoint") + restored_checkpoint = self.checkpoint_manager.restore( + directory=epath.Path(self.config.checkpoint_dir), + step=step, + args=ocp.args.Composite( + wan_state=params_restore, + wan_config=ocp.args.JsonRestore(), + ), + ) + max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") + max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}") + max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}") + max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") + return restored_checkpoint, step + + def load_diffusers_checkpoint(self): + pipeline = WanPipeline2_1.from_pretrained(self.config) + return pipeline + + def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_1, Optional[dict], Optional[int]]: + restored_checkpoint, step = self.load_wan_configs_from_orbax(step) + opt_state = None + if restored_checkpoint: + max_logging.log("Loading WAN pipeline from checkpoint") + pipeline = WanPipeline2_1.from_checkpoint(self.config, restored_checkpoint) + if "opt_state" in restored_checkpoint.wan_state.keys(): + opt_state = restored_checkpoint.wan_state["opt_state"] + else: + max_logging.log("No checkpoint found, loading default pipeline.") + pipeline = self.load_diffusers_checkpoint() + + return pipeline, opt_state, step + + def save_checkpoint(self, train_step, pipeline: WanPipeline2_1, train_states: dict): + """Saves the training state and model configurations.""" + + def config_to_json(model_or_config): + return json.loads(model_or_config.to_json_string()) + + max_logging.log(f"Saving checkpoint for step {train_step}") + items = { + "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), + } + + items["wan_state"] = ocp.args.PyTreeSave(train_states) + + # Save the checkpoint + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + max_logging.log(f"Checkpoint for step {train_step} saved.") \ No newline at end of file diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py new file mode 100644 index 000000000..502cbfbfa --- /dev/null +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py @@ -0,0 +1,114 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import json +import jax +import numpy as np +from typing import Optional, Tuple +from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 +from .. import max_logging +import orbax.checkpoint as ocp +from etils import epath +from checkpointing.wan_checkpointer import WanCheckpointer + +class WanCheckpointer2_2(WanCheckpointer): + + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: + if step is None: + step = self.checkpoint_manager.latest_step() + max_logging.log(f"Latest WAN checkpoint step: {step}") + if step is None: + max_logging.log("No WAN checkpoint found.") + return None, None + max_logging.log(f"Loading WAN checkpoint from step {step}") + metadatas = self.checkpoint_manager.item_metadata(step) + + # Handle low_noise_transformer + low_noise_transformer_metadata = metadatas.low_noise_transformer_state + abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) + low_params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_low_params, + ) + ) + + # Handle high_noise_transformer + high_noise_transformer_metadata = metadatas.high_noise_transformer_state + abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) + high_params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_high_params, + ) + ) + + max_logging.log("Restoring WAN 2.2 checkpoint") + restored_checkpoint = self.checkpoint_manager.restore( + directory=epath.Path(self.config.checkpoint_dir), + step=step, + args=ocp.args.Composite( + low_noise_transformer_state=low_params_restore, + high_noise_transformer_state=high_params_restore, + wan_config=ocp.args.JsonRestore(), + ), + ) + max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") + max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") + max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") + max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") + return restored_checkpoint, step + + def load_diffusers_checkpoint(self): + pipeline = WanPipeline2_2.from_pretrained(self.config) + return pipeline + + def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]: + restored_checkpoint, step = self.load_wan_configs_from_orbax(step) + opt_state = None + if restored_checkpoint: + max_logging.log("Loading WAN pipeline from checkpoint") + pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint) + # Check for optimizer state in either transformer + if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): + opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] + elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): + opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] + else: + max_logging.log("No checkpoint found, loading default pipeline.") + pipeline = self.load_diffusers_checkpoint() + + return pipeline, opt_state, step + + def save_checkpoint(self, train_step, pipeline: WanPipeline2_2, train_states: dict): + """Saves the training state and model configurations.""" + + def config_to_json(model_or_config): + return json.loads(model_or_config.to_json_string()) + + max_logging.log(f"Saving checkpoint for step {train_step}") + items = { + "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), + } + + items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"]) + items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"]) + + # Save the checkpoint + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + max_logging.log(f"Checkpoint for step {train_step} saved.") \ No newline at end of file diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index f03864da0..51fe2b8dc 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -44,4 +44,7 @@ KEEP_2 = "activation_keep_2" CONV_OUT = "activation_conv_out_channels" -WAN_MODEL = "Wan2.1" +WAN2_1 = "wan2.1" +WAN2_2 = "wan2.2" + +WAN_MODEL = WAN2_1 diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index c49c333f2..ffdf02eb2 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -276,8 +276,16 @@ width: 832 num_frames: 81 flow_shift: 3.0 +# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py +# guidance scale factor for low noise transformer guidance_scale_low: 3.0 + +# guidance scale factor for high noise transformer guidance_scale_high: 4.0 + +# The timestep threshold. If `t` is at or above this value, +# the `high_noise_model` is considered as the required model. +# timestep to switch between low noise and high noise transformer boundary_timestep: 875 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 071db231c..053e5f790 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -23,6 +23,7 @@ from maxdiffusion.utils import export_to_video from google.cloud import storage import flax +from maxdiffusion.common_types import WAN2_1, WAN2_2 def upload_video_to_gcs(output_dir: str, video_path: str): @@ -77,7 +78,7 @@ def get_git_commit_hash(): def call_pipeline(config, pipeline, prompt, negative_prompt): model_key = config.model_name - if model_key == "wan2.1": + if model_key == WAN2_1: return pipeline( prompt=prompt, negative_prompt=negative_prompt, @@ -87,7 +88,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, ) - elif model_key == "wan2.2": + elif model_key == WAN2_2: return pipeline( prompt=prompt, negative_prompt=negative_prompt, @@ -139,9 +140,9 @@ def run(config, pipeline=None, filename_prefix=""): max_logging.log("Could not retrieve Git commit hash.") if pipeline is None: - if model_key == "wan2.1": + if model_key == WAN2_1: checkpoint_loader = WanCheckpointer2_1(config=config) - elif model_key == "wan2.2": + elif model_key == WAN2_2: checkpoint_loader = WanCheckpointer2_2(config=config) else: raise ValueError(f"Unsupported model_name for checkpointer: {model_key}") diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 85be96977..8df815037 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -504,7 +504,7 @@ def _get_num_channel_latents(self) -> int: """Returns the number of input channels for the transformer.""" pass - def _prepare_call_inputs( + def _prepare_model_inputs( self, prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, @@ -572,230 +572,7 @@ def _prepare_call_inputs( def __call__(self, **kwargs): """Runs the inference pipeline.""" pass - -class WanPipeline2_1(WanPipeline): - """Pipeline for WAN 2.1 with a single transformer.""" - def __init__(self, config: HyperParameters, transformer: Optional[WanModel], **kwargs): - super().__init__(config=config, **kwargs) - self.transformer = transformer - - @classmethod - def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): - common_components = cls._create_common_components(config, vae_only) - transformer = None - if not vae_only: - if load_transformer: - transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer" - ) - - pipeline = cls( - tokenizer=common_components["tokenizer"], - text_encoder=common_components["text_encoder"], - transformer=transformer, - vae=common_components["vae"], - vae_cache=common_components["vae_cache"], - scheduler=common_components["scheduler"], - scheduler_state=common_components["scheduler_state"], - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - config=config, - ) - - return pipeline, transformer - - @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer) - transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) - return pipeline - - @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): - pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) - return pipeline - - def _get_num_channel_latents(self) -> int: - return self.transformer.config.in_channels - - def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: Optional[jax.Array] = None, - prompt_embeds: Optional[jax.Array] = None, - negative_prompt_embeds: Optional[jax.Array] = None, - vae_only: bool = False, - ): - latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_call_inputs( - prompt, - negative_prompt, - height, - width, - num_frames, - num_inference_steps, - num_videos_per_prompt, - max_sequence_length, - latents, - prompt_embeds, - negative_prompt_embeds, - vae_only, - ) - - graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) - - p_run_inference = partial( - run_inference_2_1, - guidance_scale=guidance_scale, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - scheduler_state=scheduler_state, - ) - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - latents = p_run_inference( - graphdef=graphdef, - sharded_state=state, - rest_of_state=rest_of_state, - latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - latents = self._denormalize_latents(latents) - return self._decode_latents_to_video(latents) - -class WanPipeline2_2(WanPipeline): - """Pipeline for WAN 2.2 with dual transformers.""" - def __init__(self, config: HyperParameters, low_noise_transformer: Optional[WanModel], high_noise_transformer: Optional[WanModel], **kwargs): - super().__init__(config=config, **kwargs) - self.low_noise_transformer = low_noise_transformer - self.high_noise_transformer = high_noise_transformer - - @classmethod - def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): - common_components = cls._create_common_components(config, vae_only) - low_noise_transformer, high_noise_transformer = None, None - if not vae_only and load_transformer: - low_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer" - ) - high_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer_2" - ) - - pipeline = cls( - tokenizer=common_components["tokenizer"], - text_encoder=common_components["text_encoder"], - low_noise_transformer=low_noise_transformer, - high_noise_transformer=high_noise_transformer, - vae=common_components["vae"], - vae_cache=common_components["vae_cache"], - scheduler=common_components["scheduler"], - scheduler_state=common_components["scheduler_state"], - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - config=config, - ) - return pipeline, low_noise_transformer, high_noise_transformer - - @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer) - low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh) - high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh) - return pipeline - - @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): - pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) - return pipeline - - def _get_num_channel_latents(self) -> int: - return self.low_noise_transformer.config.in_channels - - def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale_low: float = 3.0, - guidance_scale_high: float = 4.0, - boundary: int = 875, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: jax.Array = None, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, - ): - latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_call_inputs( - prompt, - negative_prompt, - height, - width, - num_frames, - num_inference_steps, - num_videos_per_prompt, - max_sequence_length, - latents, - prompt_embeds, - negative_prompt_embeds, - vae_only, - ) - - low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) - high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) - - p_run_inference = partial( - run_inference_2_2, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - scheduler_state=scheduler_state, - ) - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - latents = p_run_inference( - low_noise_graphdef=low_noise_graphdef, - low_noise_state=low_noise_state, - low_noise_rest=low_noise_rest, - high_noise_graphdef=high_noise_graphdef, - high_noise_state=high_noise_state, - high_noise_rest=high_noise_rest, - latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - latents = self._denormalize_latents(latents) - return self._decode_latents_to_video(latents) - + @partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) def transformer_forward_pass( graphdef, @@ -817,93 +594,3 @@ def transformer_forward_pass( latents = latents[:bsz] return noise_pred, latents - -def run_inference_2_1( - graphdef, - sharded_state, - rest_of_state, - latents: jnp.array, - prompt_embeds: jnp.array, - negative_prompt_embeds: jnp.array, - guidance_scale: float, - num_inference_steps: int, - scheduler: FlaxUniPCMultistepScheduler, - scheduler_state, -): - do_classifier_free_guidance = guidance_scale > 1.0 - if do_classifier_free_guidance: - prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - if do_classifier_free_guidance: - latents = jnp.concatenate([latents] * 2) - timestep = jnp.broadcast_to(t, latents.shape[0]) - - noise_pred, latents = transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, - prompt_embeds, - do_classifier_free_guidance=do_classifier_free_guidance, - guidance_scale=guidance_scale, - ) - - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents - -def run_inference_2_2( - low_noise_graphdef, - low_noise_state, - low_noise_rest, - high_noise_graphdef, - high_noise_state, - high_noise_rest, - latents: jnp.array, - prompt_embeds: jnp.array, - negative_prompt_embeds: jnp.array, - guidance_scale_low: float, - guidance_scale_high: float, - boundary: int, - num_inference_steps: int, - scheduler: FlaxUniPCMultistepScheduler, - scheduler_state, -): - do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 - if do_classifier_free_guidance: - prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - - def low_noise_branch(operands): - latents, timestep, prompt_embeds = operands - return transformer_forward_pass( - low_noise_graphdef, low_noise_state, low_noise_rest, - latents, timestep, prompt_embeds, - do_classifier_free_guidance, guidance_scale_low - ) - - def high_noise_branch(operands): - latents, timestep, prompt_embeds = operands - return transformer_forward_pass( - high_noise_graphdef, high_noise_state, high_noise_rest, - latents, timestep, prompt_embeds, - do_classifier_free_guidance, guidance_scale_high - ) - - for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - if do_classifier_free_guidance: - latents = jnp.concatenate([latents] * 2) - timestep = jnp.broadcast_to(t, latents.shape[0]) - - use_high_noise = jnp.greater_equal(t, boundary) - - noise_pred, latents = jax.lax.cond( - use_high_noise, - high_noise_branch, - low_noise_branch, - (latents, timestep, prompt_embeds) - ) - - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py new file mode 100644 index 000000000..bef54b740 --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -0,0 +1,162 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .wan_pipeline import WanPipeline, transformer_forward_pass +from ...models.wan.transformers.transformer_wan import WanModel +from typing import List, Union, Optional +from ...pyconfig import HyperParameters +from functools import partial +from flax import nnx +from flax.linen import partitioning as nn_partitioning +import jax +import jax.numpy as jnp +from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + +class WanPipeline2_1(WanPipeline): + """Pipeline for WAN 2.1 with a single transformer.""" + def __init__(self, config: HyperParameters, transformer: Optional[WanModel], **kwargs): + super().__init__(config=config, **kwargs) + self.transformer = transformer + + @classmethod + def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): + common_components = cls._create_common_components(config, vae_only) + transformer = None + if not vae_only: + if load_transformer: + transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer" + ) + + pipeline = cls( + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + transformer=transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, + ) + + return pipeline, transformer + + @classmethod + def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): + pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer) + transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) + return pipeline + + @classmethod + def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) + return pipeline + + def _get_num_channel_latents(self) -> int: + return self.transformer.config.in_channels + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: Optional[jax.Array] = None, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + vae_only: bool = False, + ): + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( + prompt, + negative_prompt, + height, + width, + num_frames, + num_inference_steps, + num_videos_per_prompt, + max_sequence_length, + latents, + prompt_embeds, + negative_prompt_embeds, + vae_only, + ) + + graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) + + p_run_inference = partial( + run_inference_2_1, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state, + ) + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + latents = p_run_inference( + graphdef=graphdef, + sharded_state=state, + rest_of_state=rest_of_state, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + latents = self._denormalize_latents(latents) + return self._decode_latents_to_video(latents) + +def run_inference_2_1( + graphdef, + sharded_state, + rest_of_state, + latents: jnp.array, + prompt_embeds: jnp.array, + negative_prompt_embeds: jnp.array, + guidance_scale: float, + num_inference_steps: int, + scheduler: FlaxUniPCMultistepScheduler, + scheduler_state, +): + do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + if do_classifier_free_guidance: + latents = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, latents.shape[0]) + + noise_pred, latents = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_embeds, + do_classifier_free_guidance=do_classifier_free_guidance, + guidance_scale=guidance_scale, + ) + + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py new file mode 100644 index 000000000..16d3861bc --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -0,0 +1,202 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .wan_pipeline import WanPipeline, transformer_forward_pass +from ...models.wan.transformers.transformer_wan import WanModel +from typing import List, Union, Optional +from ...pyconfig import HyperParameters +from functools import partial +from flax import nnx +from flax.linen import partitioning as nn_partitioning +import jax +import jax.numpy as jnp +from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + +class WanPipeline2_2(WanPipeline): + """Pipeline for WAN 2.2 with dual transformers.""" + def __init__(self, config: HyperParameters, low_noise_transformer: Optional[WanModel], high_noise_transformer: Optional[WanModel], **kwargs): + super().__init__(config=config, **kwargs) + self.low_noise_transformer = low_noise_transformer + self.high_noise_transformer = high_noise_transformer + + @classmethod + def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): + common_components = cls._create_common_components(config, vae_only) + low_noise_transformer, high_noise_transformer = None, None + if not vae_only and load_transformer: + low_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer" + ) + high_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer_2" + ) + + pipeline = cls( + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, + ) + return pipeline, low_noise_transformer, high_noise_transformer + + @classmethod + def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): + pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer) + low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh) + high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh) + return pipeline + + @classmethod + def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) + return pipeline + + def _get_num_channel_latents(self) -> int: + return self.low_noise_transformer.config.in_channels + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale_low: float = 3.0, + guidance_scale_high: float = 4.0, + boundary: int = 875, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, + ): + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( + prompt, + negative_prompt, + height, + width, + num_frames, + num_inference_steps, + num_videos_per_prompt, + max_sequence_length, + latents, + prompt_embeds, + negative_prompt_embeds, + vae_only, + ) + + low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) + high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) + + p_run_inference = partial( + run_inference_2_2, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state, + ) + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + latents = p_run_inference( + low_noise_graphdef=low_noise_graphdef, + low_noise_state=low_noise_state, + low_noise_rest=low_noise_rest, + high_noise_graphdef=high_noise_graphdef, + high_noise_state=high_noise_state, + high_noise_rest=high_noise_rest, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + latents = self._denormalize_latents(latents) + return self._decode_latents_to_video(latents) + +def run_inference_2_2( + low_noise_graphdef, + low_noise_state, + low_noise_rest, + high_noise_graphdef, + high_noise_state, + high_noise_rest, + latents: jnp.array, + prompt_embeds: jnp.array, + negative_prompt_embeds: jnp.array, + guidance_scale_low: float, + guidance_scale_high: float, + boundary: int, + num_inference_steps: int, + scheduler: FlaxUniPCMultistepScheduler, + scheduler_state, +): + do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 + if do_classifier_free_guidance: + prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + + def low_noise_branch(operands): + latents, timestep, prompt_embeds = operands + return transformer_forward_pass( + low_noise_graphdef, low_noise_state, low_noise_rest, + latents, timestep, prompt_embeds, + do_classifier_free_guidance, guidance_scale_low + ) + + def high_noise_branch(operands): + latents, timestep, prompt_embeds = operands + return transformer_forward_pass( + high_noise_graphdef, high_noise_state, high_noise_rest, + latents, timestep, prompt_embeds, + do_classifier_free_guidance, guidance_scale_high + ) + + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + if do_classifier_free_guidance: + latents = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, latents.shape[0]) + + use_high_noise = jnp.greater_equal(t, boundary) + + # Selects the model based on the current timestep: + # - high_noise_model: Used for early diffusion steps where t >= config.boundary_timestep (high noise). + # - low_noise_model: Used for later diffusion steps where t < config.boundary_timestep (low noise). + noise_pred, latents = jax.lax.cond( + use_high_noise, + high_noise_branch, + low_noise_branch, + (latents, timestep, prompt_embeds) + ) + + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents \ No newline at end of file diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 56eeae766..5eb3dc4ec 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -27,7 +27,24 @@ from . import max_logging from . import max_utils from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH -from maxdiffusion.common_types import LENGTH, KV_LENGTH +from maxdiffusion.common_types import LENGTH, KV_LENGTH, WAN2_1, WAN2_2 + +_ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2} +_ALLOWED_TRAINING_MODEL_NAMES = {WAN2_1} + +def _validate_model_name(model_name: str | None): + """Raise if model_name is not in the allowed list.""" + if model_name is None: + return + if model_name not in _ALLOWED_MODEL_NAMES: + raise ValueError(f"Invalid config.model_name '{model_name}'. Allowed values: {sorted(_ALLOWED_MODEL_NAMES)}") + +def _validate_training_model_name(model_name: str | None): + """Raise if model_name is not in the allowed training list.""" + if model_name is None: + return + if model_name not in _ALLOWED_TRAINING_MODEL_NAMES: + raise ValueError(f"Invalid config.model_name '{model_name}' for training. Allowed values: {sorted(_ALLOWED_TRAINING_MODEL_NAMES)}") def string_to_bool(s: str) -> bool: @@ -261,6 +278,9 @@ def get_keys(self): def initialize(argv, **kwargs): global _config, config _config = _HyperParameters(argv, **kwargs) + _validate_model_name(_config.keys.get("model_name") if hasattr(_config, "keys") else None) + if kwargs.get("validate_training", False): + _validate_training_model_name(_config.keys.get("model_name") if hasattr(_config, "keys") else None) config = HyperParameters() diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py index 79f050c07..984e80241 100644 --- a/src/maxdiffusion/tests/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -13,8 +13,8 @@ import unittest from unittest.mock import patch, MagicMock - -from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer2_1, WanCheckpointer2_2 +from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1 +from maxdiffusion.checkpointing.wan_checkpointer_2_2 import WanCheckpointer2_2 class WanCheckpointer2_1Test(unittest.TestCase): """Tests for WAN 2.1 checkpointer.""" diff --git a/src/maxdiffusion/train_wan.py b/src/maxdiffusion/train_wan.py index 3b45ea8fd..fea157203 100644 --- a/src/maxdiffusion/train_wan.py +++ b/src/maxdiffusion/train_wan.py @@ -31,7 +31,7 @@ def train(config): def main(argv: Sequence[str]) -> None: - pyconfig.initialize(argv) + pyconfig.initialize(argv, validate_training=True) config = pyconfig.config validate_train_config(config) max_logging.log(f"Found {jax.device_count()} devices.") diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 76b5a2e88..6ebd5342d 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -91,9 +91,6 @@ def __init__(self, config): if config.train_text_encoder: raise ValueError("this script currently doesn't support training text_encoders") self.config = config - model_key = config.model_name - if model_key != 'wan2.1': - raise ValueError(f"Unsupported model_name: '{model_key}'. This trainer only supports 'wan2.1'.") self.checkpointer = WanCheckpointer2_1(config=config) def post_training_steps(self, pipeline, params, train_states, msg=""): From 77d52b4f62d11d9db59ce53022f33275f6c35e5f Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 3 Dec 2025 14:07:56 +0530 Subject: [PATCH 24/28] ruff errors --- src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py | 2 +- src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py | 2 +- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 2 +- src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py | 2 +- src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py | 4 ++-- src/maxdiffusion/pyconfig.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py index 7d43582b6..31a3750de 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py @@ -92,4 +92,4 @@ def config_to_json(model_or_config): # Save the checkpoint self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) - max_logging.log(f"Checkpoint for step {train_step} saved.") \ No newline at end of file + max_logging.log(f"Checkpoint for step {train_step} saved.") diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py index 502cbfbfa..9fdfce160 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py @@ -111,4 +111,4 @@ def config_to_json(model_or_config): # Save the checkpoint self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) - max_logging.log(f"Checkpoint for step {train_step} saved.") \ No newline at end of file + max_logging.log(f"Checkpoint for step {train_step} saved.") diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 8df815037..153c225db 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -572,7 +572,7 @@ def _prepare_model_inputs( def __call__(self, **kwargs): """Runs the inference pipeline.""" pass - + @partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) def transformer_forward_pass( graphdef, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index bef54b740..8aedddd33 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -159,4 +159,4 @@ def run_inference_2_1( ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents \ No newline at end of file + return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 16d3861bc..9eb8c3e90 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -142,7 +142,7 @@ def __call__( ) latents = self._denormalize_latents(latents) return self._decode_latents_to_video(latents) - + def run_inference_2_2( low_noise_graphdef, low_noise_state, @@ -199,4 +199,4 @@ def high_noise_branch(operands): ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents \ No newline at end of file + return latents diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 5eb3dc4ec..9488f106c 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -38,7 +38,7 @@ def _validate_model_name(model_name: str | None): return if model_name not in _ALLOWED_MODEL_NAMES: raise ValueError(f"Invalid config.model_name '{model_name}'. Allowed values: {sorted(_ALLOWED_MODEL_NAMES)}") - + def _validate_training_model_name(model_name: str | None): """Raise if model_name is not in the allowed training list.""" if model_name is None: From f07503d5f01e2afd218f42d28f6c7965df4a37cb Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 3 Dec 2025 14:20:26 +0530 Subject: [PATCH 25/28] pytest errors fixed --- src/maxdiffusion/trainers/wan_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 6ebd5342d..4369d6d06 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -29,7 +29,7 @@ from maxdiffusion.schedulers import FlaxFlowMatchScheduler from flax.linen import partitioning as nn_partitioning from maxdiffusion import max_utils, max_logging, train_utils -from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer2_1 +from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1 from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) from maxdiffusion.generate_wan import run as generate_wan from maxdiffusion.generate_wan import inference_generate_video From 5c05cc3335e57e57f8dae3df963bf396d457ad5b Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 3 Dec 2025 14:28:10 +0530 Subject: [PATCH 26/28] pytest errors fixed --- src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py | 2 +- src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py index 31a3750de..a8e2a2974 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py @@ -22,7 +22,7 @@ from .. import max_logging import orbax.checkpoint as ocp from etils import epath -from checkpointing.wan_checkpointer import WanCheckpointer +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer class WanCheckpointer2_1(WanCheckpointer): diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py index 9fdfce160..30cff3871 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py @@ -22,7 +22,7 @@ from .. import max_logging import orbax.checkpoint as ocp from etils import epath -from checkpointing.wan_checkpointer import WanCheckpointer +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer class WanCheckpointer2_2(WanCheckpointer): From efc60a7ab3afcf4df41ace648bd2d68f883b7b68 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 3 Dec 2025 15:02:52 +0530 Subject: [PATCH 27/28] pytest errors fixed --- src/maxdiffusion/generate_wan.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 053e5f790..e3365e961 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -17,7 +17,8 @@ import time import os import subprocess -from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer2_1, WanCheckpointer2_2 +from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1 +from maxdiffusion.checkpointing.wan_checkpointer_2_2 import WanCheckpointer2_2 from maxdiffusion import pyconfig, max_logging, max_utils from absl import app from maxdiffusion.utils import export_to_video From 9ba8d09b89b8afa9dde73c2c94fb0a803a9ef848 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 3 Dec 2025 17:02:55 +0530 Subject: [PATCH 28/28] wan_checkpointer_test.py fixed --- .../tests/wan_checkpointer_test.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py index 984e80241..719716d74 100644 --- a/src/maxdiffusion/tests/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -25,7 +25,7 @@ def setUp(self): self.config.dataset_type = "test_dataset" @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") + @patch("maxdiffusion.checkpointing.wan_checkpointer_2_1.WanPipeline2_1") def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = None @@ -44,7 +44,7 @@ def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): self.assertIsNone(step) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") + @patch("maxdiffusion.checkpointing.wan_checkpointer_2_1.WanPipeline2_1") def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 @@ -74,7 +74,7 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag self.assertEqual(step, 1) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") + @patch("maxdiffusion.checkpointing.wan_checkpointer_2_1.WanPipeline2_1") def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 @@ -114,7 +114,7 @@ def setUp(self): self.config.dataset_type = "test_dataset" @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + @patch("maxdiffusion.checkpointing.wan_checkpointer_2_2.WanPipeline2_2") def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): """Test loading from pretrained when no checkpoint exists.""" mock_manager = MagicMock() @@ -134,7 +134,7 @@ def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): self.assertIsNone(step) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + @patch("maxdiffusion.checkpointing.wan_checkpointer_2_2.WanPipeline2_2") def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): """Test loading checkpoint without optimizer state.""" mock_manager = MagicMock() @@ -167,7 +167,7 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag self.assertEqual(step, 1) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + @patch("maxdiffusion.checkpointing.wan_checkpointer_2_2.WanPipeline2_2") def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline, mock_create_manager): """Test loading checkpoint with optimizer state in low_noise_transformer.""" mock_manager = MagicMock() @@ -201,7 +201,7 @@ def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline, mo self.assertEqual(step, 1) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + @patch("maxdiffusion.checkpointing.wan_checkpointer_2_2.WanPipeline2_2") def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, mock_create_manager): """Test loading checkpoint with optimizer state in high_noise_transformer.""" mock_manager = MagicMock() @@ -244,7 +244,7 @@ def setUp(self): self.config.dataset_type = "test_dataset" @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") + @patch("maxdiffusion.checkpointing.wan_checkpointer_2_1.WanPipeline2_1") def test_load_checkpoint_with_explicit_none_step(self, mock_wan_pipeline, mock_create_manager): """Test loading checkpoint with explicit None step falls back to latest.""" mock_manager = MagicMock() @@ -270,7 +270,7 @@ def test_load_checkpoint_with_explicit_none_step(self, mock_wan_pipeline, mock_c self.assertEqual(step, 5) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + @patch("maxdiffusion.checkpointing.wan_checkpointer_2_2.WanPipeline2_2") def test_load_checkpoint_both_optimizers_present(self, mock_wan_pipeline, mock_create_manager): """Test loading checkpoint when both transformers have optimizer state (prioritize low_noise).""" mock_manager = MagicMock()