diff --git a/README.md b/README.md index 081d65e8e..e14603ac4 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,7 @@ To generate images, run the following command: ```bash python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run" ``` + ## Flux First make sure you have permissions to access the Flux repos in Huggingface. diff --git a/end_to_end/tpu/eval_assert.py b/end_to_end/tpu/eval_assert.py index 07b5585a6..20fd0d8ae 100644 --- a/end_to_end/tpu/eval_assert.py +++ b/end_to_end/tpu/eval_assert.py @@ -22,7 +22,6 @@ """ - # pylint: skip-file """Reads and asserts over target values""" from absl import app @@ -47,7 +46,7 @@ def test_final_loss(metrics_file, target_loss, num_samples_str="10"): target_loss = float(target_loss) num_samples = int(num_samples_str) with open(metrics_file, "r", encoding="utf8") as _: - last_n_data = get_last_n_data(metrics_file, "learning/loss",num_samples) + last_n_data = get_last_n_data(metrics_file, "learning/loss", num_samples) avg_last_n_data = sum(last_n_data) / len(last_n_data) print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}") print(f"Target loss is {target_loss}") diff --git a/requirements.txt b/requirements.txt index defbb1512..e26b45b80 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,4 +30,6 @@ huggingface_hub==0.30.2 transformers==4.48.1 einops==0.8.0 sentencepiece -aqtp \ No newline at end of file +aqtp +imageio==2.37.0 +imageio-ffmpeg==0.6.0 \ No newline at end of file diff --git a/requirements_with_jax_stable_stack.txt b/requirements_with_jax_stable_stack.txt index 80ad1434e..5a88c800f 100644 --- a/requirements_with_jax_stable_stack.txt +++ b/requirements_with_jax_stable_stack.txt @@ -31,4 +31,6 @@ tensorflow-datasets>=4.9.6 tokenizers==0.21.0 torch==2.5.1 torchvision==0.20.1 -transformers==4.48.1 \ No newline at end of file +transformers==4.48.1 +imageio==2.37.0 +imageio-ffmpeg==0.6.0 \ No newline at end of file diff --git a/src/maxdiffusion/configs/base_wan_t2v.yml b/src/maxdiffusion/configs/base_wan_t2v.yml new file mode 100644 index 000000000..28ef6e77e --- /dev/null +++ b/src/maxdiffusion/configs/base_wan_t2v.yml @@ -0,0 +1,269 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' + +# Flux params +flux_name: "flux-dev" +max_sequence_length: 512 +time_shift: True +base_shift: 0.5 +max_shift: 1.15 +# offloads t5 encoder after text encoding to save memory. 
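+# With offloading enabled, the encoder weights are moved off the accelerator once
+# the text embeddings are computed, freeing HBM for the diffusion transformer at
+# the cost of re-transferring the weights if prompts need to be re-encoded.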
+offload_encoders: True
+
+
+unet_checkpoint: ''
+revision: 'refs/pr/95'
+# This will convert the weights to this dtype.
+# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
+weights_dtype: 'bfloat16'
+# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
+activations_dtype: 'bfloat16'
+
+# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
+# Options are "DEFAULT", "HIGH", "HIGHEST"
+# fp32 activations and fp32 weights with HIGHEST will provide the best precision
+# at the cost of time.
+precision: "DEFAULT"
+
+# If False, the state is not jitted and replicate is called instead. This is good for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
+# Set to True to load weights from PyTorch.
+from_pt: True
+split_head_dim: True
+attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+
+flash_block_sizes: {}
+# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
+# flash_block_sizes: {
+#   "block_q" : 1536,
+#   "block_kv_compute" : 1536,
+#   "block_kv" : 1536,
+#   "block_q_dkv" : 1536,
+#   "block_kv_dkv" : 1536,
+#   "block_kv_dkv_compute" : 1536,
+#   "block_q_dq" : 1536,
+#   "block_kv_dq" : 1536
+# }
+# GroupNorm groups
+norm_num_groups: 32
+
+# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch,
+# else they will be loaded from pretrained_model_name_or_path.
+train_new_unet: False
+
+# train text_encoder - Currently not supported for SDXL
+train_text_encoder: False
+text_encoder_learning_rate: 4.25e-6
+
+# https://arxiv.org/pdf/2305.08891.pdf
+snr_gamma: -1.0
+
+timestep_bias: {
+  # Strategy: one of none, earlier, later, range.
+  # A value of "later" will increase the frequency of the model's final training steps.
+  strategy: "none",
+  # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it.
+  multiplier: 1.0,
+  # when using strategy=range, the beginning (inclusive) timestep to bias.
+  begin: 0,
+  # when using strategy=range, the final step (inclusive) to bias.
+  end: 1000,
+  # portion of timesteps to bias.
+  # 0.5 will bias one half of the timesteps. The value of strategy determines
+  # whether the biased portions are in the earlier or later timesteps.
+  portion: 0.25
+}
+
+# Override parameters from the checkpoint's scheduler.
+diffusion_scheduler_config: {
+  _class_name: 'FlaxEulerDiscreteScheduler',
+  prediction_type: 'epsilon',
+  rescale_zero_terminal_snr: False,
+  timestep_spacing: 'trailing'
+}
+
+# Output directory
+# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
+base_output_directory: ""
+
+# Hardware
+hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
+
+# Parallelism
+mesh_axes: ['data', 'fsdp', 'tensor']
+
+# batch : batch dimension of data and activations
+# hidden :
+# embed : attention qkv dense layer hidden dim named as embed
+# heads : attention head dim = num_heads * head_dim
+# length : attention sequence length
+# temb_in : dense.shape[0] of resnet dense before conv
+# out_c : dense.shape[1] of resnet dense before conv
+# out_channels : conv.shape[-1] activation
+# keep_1 : conv.shape[0] weight
+# keep_2 : conv.shape[1] weight
+# conv_in : conv.shape[2] weight
+# conv_out : conv.shape[-1] weight
+logical_axis_rules: [
+    ['batch', 'data'],
+    ['activation_batch', ['data','fsdp']],
+    ['activation_heads', 'tensor'],
+    ['activation_kv', 'tensor'],
+    ['mlp','tensor'],
+    ['embed','fsdp'],
+    ['heads', 'tensor'],
+    ['conv_batch', ['data','fsdp']],
+    ['out_channels', 'tensor'],
+    ['conv_out', 'fsdp'],
+  ]
+data_sharding: [['data', 'fsdp', 'tensor']]
+
+# One axis for each parallelism type may hold a placeholder (-1)
+# value to auto-shard based on available slices and devices.
+# By default, product of the DCN axes should equal number of slices
+# and product of the ICI axes should equal number of devices per slice.
+dcn_data_parallelism: 1
+dcn_fsdp_parallelism: -1 # recommended DCN axis to be auto-sharded
+dcn_tensor_parallelism: 1
+ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded
+ici_fsdp_parallelism: 1
+ici_tensor_parallelism: 1
+
+# Dataset
+# Set dataset_name or train_data_dir; one of them must be set.
+dataset_name: 'diffusers/pokemon-gpt4-captions'
+train_split: 'train'
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and only to small datasets that fit in memory.
+# It precomputes image latents and text encoder outputs
+# to reduce memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
+dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
+train_data_dir: ''
+dataset_config_name: ''
+jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
+image_column: 'image'
+caption_column: 'text'
+resolution: 1024
+center_crop: False
+random_flip: False
+# If cache_latents_text_encoder_outputs is True,
+# the num_proc is set to 1.
+tokenize_captions_num_proc: 4
+transform_images_num_proc: 4
+reuse_example_batch: False
+enable_data_shuffling: True
+
+# Checkpoint every N samples; -1 means don't checkpoint.
+checkpoint_every: -1
+# Enables one replica to read the checkpoint and then broadcast it to the rest.
+enable_single_replica_ckpt_restoring: False
+
+# Training loop
+learning_rate: 4.e-7
+scale_lr: False
+max_train_samples: -1
+# max_train_steps takes priority over num_train_epochs.
+max_train_steps: 200
+num_train_epochs: 1
+seed: 0
+output_dir: 'sdxl-model-finetuned'
+per_device_batch_size: 1
+
+warmup_steps_fraction: 0.0
+learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
+
+# However, you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case training will end before
+# the learning rate has fully decayed. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
+
+# AdamW optimizer parameters
+adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
+adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 1.e-2 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Generation parameters +prompt: "A magical castle in the middle of a forest, artistic drawing" +prompt_2: "A magical castle in the middle of a forest, artistic drawing" +negative_prompt: "purple, red" +do_classifier_free_guidance: True +guidance_scale: 3.5 +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 50 + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. + diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py index 8bcb6d80d..bfd420f72 100644 --- a/src/maxdiffusion/configuration_utils.py +++ b/src/maxdiffusion/configuration_utils.py @@ -464,7 +464,8 @@ def extract_init_dict(cls, config_dict, **kwargs): # remove flax internal keys if hasattr(cls, "_flax_internal_args"): for arg in cls._flax_internal_args: - expected_keys.remove(arg) + if arg in expected_keys: + expected_keys.remove(arg) # 2. Remove attributes that cannot be expected from expected config attributes # remove keys to be ignored diff --git a/src/maxdiffusion/image_processor.py b/src/maxdiffusion/image_processor.py index 788e1c94e..76fa7635e 100644 --- a/src/maxdiffusion/image_processor.py +++ b/src/maxdiffusion/image_processor.py @@ -36,6 +36,53 @@ ] +def is_valid_image(image) -> bool: + r""" + Checks if the input is a valid image. + + A valid image can be: + - A `PIL.Image.Image`. + - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. + + Returns: + `bool`: + `True` if the input is a valid image, `False` otherwise. + """ + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) + + +def is_valid_image_imagelist(images): + r""" + Checks if the input is a valid image or list of images. 
+
+    The input can be one of the following formats:
+    - A 4D tensor or numpy array (batch of images).
+    - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
+      `torch.Tensor`.
+    - A list of valid images.
+
+    Args:
+        images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
+            The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
+            images.
+
+    Returns:
+        `bool`:
+            `True` if the input is valid, `False` otherwise.
+    """
+    if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
+        return True
+    elif is_valid_image(images):
+        return True
+    elif isinstance(images, list):
+        return all(is_valid_image(image) for image in images)
+    return False
+
+
 class VaeImageProcessor(ConfigMixin):
     """
     Image processor for VAE.
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 839903406..2f8946056 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -14,7 +14,7 @@
 
 import functools
 import math
-
+from typing import Optional
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -406,6 +406,142 @@ def chunk_scanner(chunk_idx, _):
   return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
 
 
+def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
+  xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
+  xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
+
+  xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+  xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+
+  return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
+
+
+class FlaxWanAttention(nn.Module):
+  query_dim: int
+  heads: int = 8
+  dim_head: int = 64
+  dropout: float = 0.0
+  use_memory_efficient_attention: bool = False
+  split_head_dim: bool = False
+  attention_kernel: str = "dot_product"
+  flash_min_seq_length: int = 4096
+  flash_block_sizes: BlockSizes = None
+  mesh: jax.sharding.Mesh = None
+  dtype: jnp.dtype = jnp.float32
+  weights_dtype: jnp.dtype = jnp.float32
+  query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  out_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
+  precision: jax.lax.Precision = None
+  qkv_bias: bool = False
+
+  def setup(self):
+    if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None:
+      raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
+    inner_dim = self.dim_head * self.heads
+    scale = self.dim_head**-0.5
+
+    self.attention_op = AttentionOp(
+        mesh=self.mesh,
+        attention_kernel=self.attention_kernel,
+        scale=scale,
+        heads=self.heads,
+        dim_head=self.dim_head,
+        flash_min_seq_length=self.flash_min_seq_length,
+        use_memory_efficient_attention=self.use_memory_efficient_attention,
+        split_head_dim=self.split_head_dim,
+        flash_block_sizes=self.flash_block_sizes,
+        dtype=self.dtype,
+        float32_qk_product=False,
+    )
+
+    kernel_axes = ("embed", "heads")
+    qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes)
+
+    self.query = nn.Dense(
+        inner_dim,
+        kernel_init=qkv_init_kernel,
+        use_bias=False,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_q",
+        precision=self.precision,
+    )
+
+    self.key = nn.Dense(
+        inner_dim,
+        kernel_init=qkv_init_kernel,
+        use_bias=False,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_k",
+        precision=self.precision,
+    )
+
+    self.value = nn.Dense(
+        inner_dim,
+        kernel_init=qkv_init_kernel,
+        use_bias=False,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_v",
+        precision=self.precision,
+    )
+
+    self.query_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+    self.key_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+
+    self.proj_attn = nn.Dense(
+        self.query_dim,
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("heads", "embed")),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_out_0",
+        precision=self.precision,
+    )
+    self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+  def __call__(
+      self,
+      hidden_states: Array,
+      encoder_hidden_states: Optional[Array],
+      rotary_emb: Optional[Array],
+      deterministic: bool = True,
+  ):
+    encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
+
+    query_proj = self.query(hidden_states)
+    key_proj = self.key(encoder_hidden_states)
+    value_proj = self.value(encoder_hidden_states)
+
+    query_proj = self.query_norm(query_proj)
+    key_proj = self.key_norm(key_proj)
+
+    if rotary_emb is not None:
+      query_proj, key_proj = apply_rope(query_proj, key_proj, rotary_emb)
+
+    query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
+    key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
+    value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
+
+    hidden_states = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+
+    hidden_states = self.proj_attn(hidden_states)
+    hidden_states = nn.with_logical_constraint(hidden_states, self.out_axis_names)
+    return self.dropout_layer(hidden_states, deterministic=deterministic)
+
+
 class FlaxFluxAttention(nn.Module):
   query_dim: int
   heads: int = 8
@@ -515,15 +651,6 @@ def setup(self):
         param_dtype=self.weights_dtype,
     )
 
-  def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
-    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
-
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-
-    return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
-
   def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
 
     qkv_proj = self.qkv(hidden_states)
@@ -557,7 +684,7 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=Non
       value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
 
     image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)
-    query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb)
+    query_proj, key_proj = apply_rope(query_proj, key_proj, image_rotary_emb)
 
     query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1)
     key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1)
diff --git a/src/maxdiffusion/models/embeddings_flax.py 
b/src/maxdiffusion/models/embeddings_flax.py index 42ca4b950..cc961e131 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -91,7 +91,8 @@ class FlaxTimesteps(nn.Module): dim: int = 32 flip_sin_to_cos: bool = False - freq_shift: float = 1 + freq_shift: float = 1.0 + scale: int = 1 @nn.compact def __call__(self, timesteps): diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index 362a39171..d856fda5d 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -4,14 +4,13 @@ import jax from jax.typing import DTypeLike -import torch # need for torch 2 jax from chex import Array from flax.traverse_util import flatten_dict, unflatten_dict from huggingface_hub import hf_hub_download from jax import numpy as jnp from safetensors import safe_open -from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor) +from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax) from maxdiffusion import max_logging @@ -33,20 +32,6 @@ class FluxParams: param_dtype: DTypeLike -def torch2jax(torch_tensor: torch.Tensor) -> Array: - is_bfloat16 = torch_tensor.dtype == torch.bfloat16 - if is_bfloat16: - # upcast the tensor to fp32 - torch_tensor = torch_tensor.float() - - if torch.device.type != "cpu": - torch_tensor = torch_tensor.to("cpu") - - numpy_value = torch_tensor.numpy() - jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None) - return jax_array - - @dataclass class ModelSpec: params: FluxParams diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 9552c69f1..d6a448f98 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -15,19 +15,61 @@ """ PyTorch - Flax general utilities.""" import re +import torch import jax import jax.numpy as jnp from flax.linen import Partitioned from flax.traverse_util import flatten_dict, unflatten_dict from flax.core.frozen_dict import unfreeze from jax.random import PRNGKey - +from chex import Array from ..utils import logging +from .. import max_logging logger = logging.get_logger(__name__) +def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict): + """ + expected_pytree: dict - a pytree that comes from initializing the model. + new_pytree: dict - a pytree that has been created from pytorch weights. 
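+  Logs missing keys and shape mismatches instead of raising, so a partially
+  converted checkpoint can be inspected key by key.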
+  """
+  expected_pytree = flatten_dict(expected_pytree)
+  if len(expected_pytree.keys()) != len(new_pytree.keys()):
+    set1 = set(expected_pytree.keys())
+    set2 = set(new_pytree.keys())
+    missing_keys = set1 ^ set2
+    max_logging.log(f"missing or unexpected keys: {missing_keys}")
+  for key in expected_pytree.keys():
+    if key in new_pytree.keys():
+      try:
+        expected_pytree_shape = expected_pytree[key].shape
+      except Exception:
+        expected_pytree_shape = expected_pytree[key].value.shape
+      if expected_pytree_shape != new_pytree[key].shape:
+        max_logging.log(
+            f"shape mismatch for {key}: expected shape {expected_pytree_shape}, but got shape {new_pytree[key].shape}"
+        )
+    else:
+      max_logging.log(f"key: {key} not found...")
+
+
+def torch2jax(torch_tensor: torch.Tensor) -> Array:
+  is_bfloat16 = torch_tensor.dtype == torch.bfloat16
+  if is_bfloat16:
+    # upcast the tensor to fp32
+    torch_tensor = torch_tensor.float()
+
+  if torch_tensor.device.type != "cpu":
+    torch_tensor = torch_tensor.to("cpu")
+
+  numpy_value = torch_tensor.numpy()
+  jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None)
+  return jax_array
+
+
 def rename_key(key):
   regex = r"\w+[.]\d+"
   pats = re.findall(regex, key)
@@ -94,6 +136,12 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic
     pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
     return renamed_pt_tuple_key, pt_tensor
 
+  # 3d conv layer
+  renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
+  if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 5:
+    pt_tensor = pt_tensor.transpose(2, 3, 4, 1, 0)
+    return renamed_pt_tuple_key, pt_tensor
+
   # linear layer
   renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
   if pt_tuple_key[-1] == "weight":
@@ -103,6 +151,8 @@
   # old PyTorch layer norm weight
   renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
   if pt_tuple_key[-1] == "gamma":
+    renamed_pt_tuple_key = pt_tuple_key
+    pt_tensor = pt_tensor.flatten()
     return renamed_pt_tuple_key, pt_tensor
 
   # old PyTorch layer norm bias
diff --git a/src/maxdiffusion/models/wan/__init__.py b/src/maxdiffusion/models/wan/__init__.py
new file mode 100644
index 000000000..7e4185f36
--- /dev/null
+++ b/src/maxdiffusion/models/wan/__init__.py
@@ -0,0 +1,15 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
new file mode 100644
index 000000000..8325c3707
--- /dev/null
+++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -0,0 +1,853 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Tuple, List, Sequence, Union, Optional + +import jax +import jax.numpy as jnp +from flax import nnx +from ...configuration_utils import ConfigMixin +from ..modeling_flax_utils import FlaxModelMixin +from ... import common_types +from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) + +BlockSizes = common_types.BlockSizes + +CACHE_T = 2 + +_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "mish": jax.nn.mish} + + +def get_activation(name: str): + func = _ACTIVATIONS.get(name) + if func is None: + raise ValueError(f"Unknown activation function: {name}") + return func + + +# Helper to ensure kernel_size, stride, padding are tuples of 3 integers +def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: + """Canonicalizes a value to a tuple of integers.""" + if isinstance(x, int): + return (x,) * rank + elif isinstance(x, Sequence) and len(x) == rank: + return tuple(x) + else: + raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank} integers. Got {x}") + + +class WanCausalConv3d(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, # rngs are required for initializing parameters, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + use_bias: bool = True, + ): + self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") + self.stride = _canonicalize_tuple(stride, 3, "stride") + padding_tuple = _canonicalize_tuple(padding, 3, "padding") # (D, H, W) padding amounts + + self._causal_padding = ( + (0, 0), # Batch dimension - no padding + (2 * padding_tuple[0], 0), # Depth dimension - causal padding (pad only before) + (padding_tuple[1], padding_tuple[1]), # Height dimension - symmetric padding + (padding_tuple[2], padding_tuple[2]), # Width dimension - symmetric padding + (0, 0), # Channel dimension - no padding + ) + + # Store the amount of padding needed *before* the depth dimension for caching logic + self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] + + self.conv = nnx.Conv( + in_features=in_channels, + out_features=out_channels, + kernel_size=self.kernel_size, + strides=self.stride, + use_bias=use_bias, + padding="VALID", # Handle padding manually + rngs=rngs, + ) + + def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array: + current_padding = list(self._causal_padding) # Mutable copy + padding_needed = self._depth_padding_before + + if cache_x is not None and padding_needed > 0: + # Ensure cache has same spatial/channel dims, potentially different depth + assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" + cache_len = cache_x.shape[1] + x = jnp.concatenate([cache_x, x], axis=1) # Concat along depth (D) + + padding_needed -= cache_len + if padding_needed < 0: + # Cache longer than needed padding, trim from start + x = x[:, 
-padding_needed:, ...]
+        current_padding[1] = (0, 0)  # No explicit padding needed now
+      else:
+        # Update depth padding needed
+        current_padding[1] = (padding_needed, 0)
+
+    # Apply padding if any dimension requires it
+    padding_to_apply = tuple(current_padding)
+    if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads):
+      x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
+    else:
+      x_padded = x
+    out = self.conv(x_padded)
+    return out
+
+
+class WanRMS_norm(nnx.Module):
+
+  def __init__(
+      self,
+      dim: int,
+      rngs: nnx.Rngs,
+      channel_first: bool = True,
+      images: bool = True,
+      eps: float = 1e-6,
+      use_bias: bool = False,
+  ):
+    broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+    shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+    self.eps = eps
+    self.channel_first = channel_first
+    self.scale = dim**0.5
+    # Initialize gamma as parameter
+    self.gamma = nnx.Param(jnp.ones(shape))
+    if use_bias:
+      self.bias = nnx.Param(jnp.zeros(shape))
+    else:
+      self.bias = 0
+
+  def __call__(self, x: jax.Array) -> jax.Array:
+    normalized = jnp.linalg.norm(x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True)
+    normalized = x / jnp.maximum(normalized, self.eps)
+    normalized = normalized * self.scale * self.gamma
+    # isinstance check avoids evaluating the truth value of an array-valued bias.
+    if isinstance(self.bias, nnx.Param):
+      return normalized + self.bias.value
+    return normalized
+
+
+class WanUpsample(nnx.Module):
+
+  def __init__(self, scale_factor: Tuple[float, float], method: str = "nearest"):
+    # scale_factor for (H, W)
+    # JAX resize works on spatial dims H, W, assuming (N, D, H, W, C) or (N, H, W, C)
+    self.scale_factor = scale_factor
+    self.method = method
+
+  def __call__(self, x: jax.Array) -> jax.Array:
+    input_dtype = x.dtype
+    in_shape = x.shape
+    assert len(in_shape) == 4, "This module only takes rank-4 tensors."
+    n, h, w, c = in_shape
+    target_h = int(h * self.scale_factor[0])
+    target_w = int(w * self.scale_factor[1])
+    out = jax.image.resize(x.astype(jnp.float32), (n, target_h, target_w, c), method=self.method)
+    return out.astype(input_dtype)
+
+
+class Identity(nnx.Module):
+
+  def __call__(self, x):
+    return x
+
+
+class ZeroPaddedConv2D(nnx.Module):
+  """
+  Module for adding padding before the conv.
+  Currently it does not add any explicit padding itself.
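+  Instead it relies on nnx.Conv's default "SAME" padding: with kernel_size 3 and
+  stride 2 on even spatial sizes, "SAME" pads (0, 1) on each spatial dim, which
+  should match the PyTorch reference's nn.ZeroPad2d((0, 1, 0, 1)) + valid conv.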
+ """ + + def __init__( + self, + dim: int, + rngs: nnx.Rngs, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + ): + self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs) + + def __call__(self, x): + return self.conv(x) + + +class WanResample(nnx.Module): + + def __init__( + self, + dim: int, + mode: str, + rngs: nnx.Rngs, + ): + self.dim = dim + self.mode = mode + self.time_conv = None + + if mode == "upsample2d": + self.resample = nnx.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), + nnx.Conv( + dim, + dim // 2, + kernel_size=(3, 3), + padding="SAME", + use_bias=True, + rngs=rngs, + ), + ) + elif mode == "upsample3d": + self.resample = nnx.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), + nnx.Conv( + dim, + dim // 2, + kernel_size=(3, 3), + padding="SAME", + use_bias=True, + rngs=rngs, + ), + ) + self.time_conv = WanCausalConv3d( + rngs=rngs, + in_channels=dim, + out_channels=dim * 2, + kernel_size=(3, 1, 1), + padding=(1, 0, 0), + ) + elif mode == "downsample2d": + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2)) + elif mode == "downsample3d": + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2)) + self.time_conv = WanCausalConv3d( + rngs=rngs, in_channels=dim, out_channels=dim, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) + else: + self.resample = Identity() + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: + # Input x: (N, D, H, W, C), assume C = self.dim + b, t, h, w, c = x.shape + assert c == self.dim + + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], axis=1) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x = x.reshape(b, t, h, w, 2, c) + x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1) + x = x.reshape(b, t * 2, h, w, c) + t = x.shape[1] + x = x.reshape(b * t, h, w, c) + x = self.resample(x) + h_new, w_new, c_new = x.shape[1:] + x = x.reshape(b, t, h_new, w_new, c_new) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = jnp.copy(x) + feat_idx[0] += 1 + else: + cache_x = jnp.copy(x[:, -1:, :, :, :]) + x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + return x + + +class WanResidualBlock(nnx.Module): + + def __init__( + self, + in_dim: int, + out_dim: int, + rngs: nnx.Rngs, + dropout: float = 0.0, + non_linearity: str = "silu", + ): + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False) + self.conv1 = WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=3, padding=1) + self.norm2 = 
WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False)
+    self.conv2 = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=out_dim, kernel_size=3, padding=1)
+    self.conv_shortcut = (
+        WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=1)
+        if in_dim != out_dim
+        else Identity()
+    )
+
+  def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
+    # Apply shortcut connection
+    h = self.conv_shortcut(x)
+
+    x = self.norm1(x)
+    x = self.nonlinearity(x)
+
+    if feat_cache is not None:
+      idx = feat_idx[0]
+      cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+      x = self.conv1(x, feat_cache[idx], idx)
+      feat_cache[idx] = cache_x
+      feat_idx[0] += 1
+    else:
+      x = self.conv1(x)
+
+    x = self.norm2(x)
+    x = self.nonlinearity(x)
+
+    if feat_cache is not None:
+      idx = feat_idx[0]
+      cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+      x = self.conv2(x, feat_cache[idx])
+      feat_cache[idx] = cache_x
+      feat_idx[0] += 1
+    else:
+      x = self.conv2(x)
+    x = x + h
+    return x
+
+
+class WanAttentionBlock(nnx.Module):
+
+  def __init__(self, dim: int, rngs: nnx.Rngs):
+    self.dim = dim
+    self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False)
+    self.to_qkv = nnx.Conv(in_features=dim, out_features=dim * 3, kernel_size=(1, 1), rngs=rngs)
+    self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=(1, 1), rngs=rngs)
+
+  def __call__(self, x: jax.Array):
+
+    identity = x
+    batch_size, time, height, width, channels = x.shape
+
+    x = x.reshape(batch_size * time, height, width, channels)
+    x = self.norm(x)
+
+    qkv = self.to_qkv(x)  # Output: (N*D, H, W, C * 3)
+    # Single-head attention over the H*W tokens of each frame.
+    # jax.nn.dot_product_attention expects (batch, seq_len, num_heads, head_dim).
+    qkv = qkv.reshape(batch_size * time, height * width, 1, channels * 3)
+    q, k, v = jnp.split(qkv, 3, axis=-1)
+    x = jax.nn.dot_product_attention(q, k, v)
+    x = x.reshape(batch_size * time, height, width, channels)
+
+    # output projection
+    x = self.proj(x)
+    # Reshape back
+    x = x.reshape(batch_size, time, height, width, channels)
+
+    return x + identity
+
+
+class WanMidBlock(nnx.Module):
+
+  def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
+    self.dim = dim
+    resnets = [WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity)]
+    attentions = []
+    for _ in range(num_layers):
+      attentions.append(WanAttentionBlock(dim=dim, rngs=rngs))
+      resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity))
+    self.attentions = attentions
+    self.resnets = resnets
+
+  def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
+    x = self.resnets[0](x, feat_cache, feat_idx)
+    for attn, resnet in zip(self.attentions, self.resnets[1:]):
+      if attn is not None:
+        x = attn(x)
+      x = resnet(x, feat_cache, feat_idx)
+    return x
+
+
+class WanUpBlock(nnx.Module):
+
+  def __init__(
+      self,
+      in_dim: int,
+      out_dim: int,
+      num_res_blocks: int,
+      rngs: nnx.Rngs,
+      
dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append( + WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs) + ) + current_dim = out_dim + self.resnets = resnets + + # Add upsampling layer if needed. + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = [WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs)] + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class WanEncoder3d(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + dim: int = 128, + z_dim: int = 4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + ): + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = WanCausalConv3d( + rngs=rngs, + in_channels=3, + out_channels=dims[0], + kernel_size=3, + padding=1, + ) + + # downsample blocks + self.down_blocks = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(WanResidualBlock(in_dim=in_dim, out_dim=out_dim, dropout=dropout, rngs=rngs)) + if scale in attn_scales: + self.down_blocks.append(WanAttentionBlock(dim=out_dim, rngs=rngs)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs)) + scale /= 2.0 + + # middle_blocks + self.mid_block = WanMidBlock( + dim=out_dim, + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + num_layers=1, + ) + + # output blocks + self.norm_out = WanRMS_norm(out_dim, channel_first=False, images=False, rngs=rngs) + self.conv_out = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=z_dim, kernel_size=3, padding=1) + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T:, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + # cache last frame of the last two chunk + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + x = self.mid_block(x, feat_cache, feat_idx) + + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk 
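+        # feat_cache[idx] holds the trailing CACHE_T frames of the previous chunk
+        # so the causal conv sees a continuous stream across chunk boundaries;
+        # when the current chunk has a single frame, the last cached frame is
+        # prepended so the cache always spans two frames.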
+        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+      x = self.conv_out(x, feat_cache[idx])
+      feat_cache[idx] = cache_x
+      feat_idx[0] += 1
+    else:
+      x = self.conv_out(x)
+    return x
+
+
+class WanDecoder3d(nnx.Module):
+  r"""
+  A 3D decoder module.
+  Args:
+    dim (int): The base number of channels in the first layer.
+    z_dim (int): The dimensionality of the latent space.
+    dim_mult (list of int): Multipliers for the number of channels in each block.
+    num_res_blocks (int): Number of residual blocks in each block.
+    attn_scales (list of float): Scales at which to apply attention mechanisms.
+    temperal_upsample (list of bool): Whether to upsample temporally in each block.
+    dropout (float): Dropout rate for the dropout layers.
+    non_linearity (str): Type of non-linearity to use.
+  """
+
+  def __init__(
+      self,
+      rngs: nnx.Rngs,
+      dim: int = 128,
+      z_dim: int = 4,
+      dim_mult: List[int] = [1, 2, 4, 4],
+      num_res_blocks: int = 2,
+      attn_scales: List[float] = [],
+      temperal_upsample=[False, True, True],
+      dropout=0.0,
+      non_linearity: str = "silu",
+  ):
+    self.dim = dim
+    self.z_dim = z_dim
+    self.dim_mult = dim_mult
+    self.num_res_blocks = num_res_blocks
+    self.attn_scales = attn_scales
+    self.temperal_upsample = temperal_upsample
+
+    self.nonlinearity = get_activation(non_linearity)
+
+    # dimensions
+    dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+    scale = 1.0 / 2 ** (len(dim_mult) - 2)
+
+    # init block
+    self.conv_in = WanCausalConv3d(rngs=rngs, in_channels=z_dim, out_channels=dims[0], kernel_size=3, padding=1)
+
+    # middle_blocks
+    self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1)
+
+    # upsample blocks
+    self.up_blocks = []
+    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+      # residual (+attention) blocks
+      if i > 0:
+        in_dim = in_dim // 2
+
+      # Determine if we need upsampling
+      upsample_mode = None
+      if i != len(dim_mult) - 1:
+        upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
+      # Create and add the upsampling block
+      up_block = WanUpBlock(
+          in_dim=in_dim,
+          out_dim=out_dim,
+          num_res_blocks=num_res_blocks,
+          dropout=dropout,
+          upsample_mode=upsample_mode,
+          non_linearity=non_linearity,
+          rngs=rngs,
+      )
+      self.up_blocks.append(up_block)
+
+      # Update scale for next iteration
+      if upsample_mode is not None:
+        scale *= 2.0
+
+    # output blocks
+    self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False)
+    self.conv_out = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=3, kernel_size=3, padding=1)
+
+  def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
+    if feat_cache is not None:
+      idx = feat_idx[0]
+      cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+        # cache last frame of the last two chunks
+        cache_x = 
jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class AutoencoderKLWanCache: + + def __init__(self, module): + self.module = module + self.clear_cache() + + def clear_cache(self): + """Resets cache dictionaries and indices""" + + def _count_conv3d(module): + count = 0 + node_types = nnx.graph.iter_graph([module]) + for _, value in node_types: + if isinstance(value, WanCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.module.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.module.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): + + def __init__( + self, + rngs: nnx.Rngs, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + latents_std: List[float] = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ], + ): + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temporal_upsample = temperal_downsample[::-1] + self.latents_mean = latents_mean + self.latents_std = latents_std + + self.encoder = WanEncoder3d( + rngs=rngs, + dim=base_dim, + z_dim=z_dim * 2, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + dropout=dropout, + ) + self.quant_conv = WanCausalConv3d(rngs=rngs, in_channels=z_dim * 2, out_channels=z_dim * 2, kernel_size=1) + self.post_quant_conv = WanCausalConv3d( + rngs=rngs, + in_channels=z_dim, + out_channels=z_dim, + kernel_size=1, + ) + + self.decoder = WanDecoder3d( + rngs=rngs, + dim=base_dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_upsample=self.temporal_upsample, + dropout=dropout, + ) + + def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): + feat_cache.clear_cache() + if x.shape[-1] != 3: + # reshape channel last for JAX + x = jnp.transpose(x, (0, 2, 3, 4, 1)) + assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}" + + t = x.shape[1] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + feat_cache._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :1, :, :, :], feat_cache=feat_cache._enc_feat_map, feat_idx=feat_cache._enc_conv_idx) + else: + out_ = self.encoder( + x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], + feat_cache=feat_cache._enc_feat_map, + feat_idx=feat_cache._enc_conv_idx, + ) + out = jnp.concatenate([out, out_], axis=1) + enc = self.quant_conv(out) + mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] + enc = jnp.concatenate([mu, logvar], axis=-1) + feat_cache.clear_cache() + return enc + + def encode( + self, x: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True + ) -> Union[FlaxAutoencoderKLOutput, 
Tuple[FlaxDiagonalGaussianDistribution]]:
+    """Encode video into latent distribution."""
+    h = self._encode(x, feat_cache)
+    posterior = FlaxDiagonalGaussianDistribution(h)
+    if not return_dict:
+      return (posterior,)
+    return FlaxAutoencoderKLOutput(latent_dist=posterior)
+
+  def _decode(
+      self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
+  ) -> Union[FlaxDecoderOutput, jax.Array]:
+    feat_cache.clear_cache()
+    iter_ = z.shape[1]
+    x = self.post_quant_conv(z)
+    for i in range(iter_):
+      feat_cache._conv_idx = [0]
+      if i == 0:
+        out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=feat_cache._feat_map, feat_idx=feat_cache._conv_idx)
+      else:
+        out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=feat_cache._feat_map, feat_idx=feat_cache._conv_idx)
+
+        # This is to bypass an issue where frame[1] should be frame[2] and vice versa.
+        # Ideally this shouldn't be needed; the point where the frames go out of sync
+        # hasn't been found yet, most likely an incorrect reshaping in the decoder.
+        fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
+        if len(fm1.shape) == 4:
+          fm1 = jnp.expand_dims(fm1, axis=0)
+          fm2 = jnp.expand_dims(fm2, axis=0)
+          fm3 = jnp.expand_dims(fm3, axis=0)
+          fm4 = jnp.expand_dims(fm4, axis=0)
+
+        out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
+    out = jnp.clip(out, min=-1.0, max=1.0)
+    feat_cache.clear_cache()
+    if not return_dict:
+      return (out,)
+
+    return FlaxDecoderOutput(sample=out)
+
+  def decode(
+      self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
+  ) -> Union[FlaxDecoderOutput, jax.Array]:
+    if z.shape[-1] != self.z_dim:
+      # reshape channel last for JAX
+      z = jnp.transpose(z, (0, 2, 3, 4, 1))
+      assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}), got {z.shape}"
+    decoded = self._decode(z, feat_cache).sample
+    if not return_dict:
+      return (decoded,)
+    return FlaxDecoderOutput(sample=decoded)
diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py
new file mode 100644
index 000000000..1a9948fdb
--- /dev/null
+++ b/src/maxdiffusion/models/wan/wan_utils.py
@@ -0,0 +1,75 @@
+import jax
+import jax.numpy as jnp
+from maxdiffusion import max_logging
+from huggingface_hub import hf_hub_download
+from safetensors import safe_open
+from flax.traverse_util import unflatten_dict
+from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
+
+
+def _tuple_str_to_int(in_tuple):
+  out_list = []
+  for item in in_tuple:
+    try:
+      out_list.append(int(item))
+    except ValueError:
+      out_list.append(item)
+  return tuple(out_list)
+
+
+def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
+  device = jax.devices(device)[0]
+  with jax.default_device(device):
+    if hf_download:
+      ckpt_path = hf_hub_download(
+          pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors"
+      )
+    max_logging.log(f"Load and port Wan 2.1 VAE on {device}")
+
+    if ckpt_path is not None:
+      tensors = {}
+      with safe_open(ckpt_path, framework="pt") as f:
+        for k in f.keys():
+          tensors[k] = torch2jax(f.get_tensor(k))
+      flax_state_dict = {}
+      cpu = jax.local_devices(backend="cpu")[0]
+      for pt_key, tensor in tensors.items():
+        renamed_pt_key = rename_key(pt_key)
+        # Order matters
+        renamed_pt_key = renamed_pt_key.replace("up_blocks_", "up_blocks.")
+        renamed_pt_key = 
renamed_pt_key.replace("mid_block_", "mid_block.") + renamed_pt_key = renamed_pt_key.replace("down_blocks_", "down_blocks.") + + renamed_pt_key = renamed_pt_key.replace("conv_in.bias", "conv_in.conv.bias") + renamed_pt_key = renamed_pt_key.replace("conv_in.weight", "conv_in.conv.weight") + renamed_pt_key = renamed_pt_key.replace("conv_out.bias", "conv_out.conv.bias") + renamed_pt_key = renamed_pt_key.replace("conv_out.weight", "conv_out.conv.weight") + renamed_pt_key = renamed_pt_key.replace("attentions_", "attentions.") + renamed_pt_key = renamed_pt_key.replace("resnets_", "resnets.") + renamed_pt_key = renamed_pt_key.replace("upsamplers_", "upsamplers.") + renamed_pt_key = renamed_pt_key.replace("resample_", "resample.") + renamed_pt_key = renamed_pt_key.replace("conv1.bias", "conv1.conv.bias") + renamed_pt_key = renamed_pt_key.replace("conv1.weight", "conv1.conv.weight") + renamed_pt_key = renamed_pt_key.replace("conv2.bias", "conv2.conv.bias") + renamed_pt_key = renamed_pt_key.replace("conv2.weight", "conv2.conv.weight") + renamed_pt_key = renamed_pt_key.replace("time_conv.bias", "time_conv.conv.bias") + renamed_pt_key = renamed_pt_key.replace("time_conv.weight", "time_conv.conv.weight") + renamed_pt_key = renamed_pt_key.replace("quant_conv", "quant_conv.conv") + renamed_pt_key = renamed_pt_key.replace("conv_shortcut", "conv_shortcut.conv") + if "decoder" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("resample.1.bias", "resample.layers.1.bias") + renamed_pt_key = renamed_pt_key.replace("resample.1.weight", "resample.layers.1.weight") + if "encoder" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("resample.1", "resample.conv") + pt_tuple_key = tuple(renamed_pt_key.split(".")) + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes) + flax_key = _tuple_str_to_int(flax_key) + flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) + del tensors + jax.clear_caches() + else: + raise FileNotFoundError(f"Path {ckpt_path} was not found") + + return flax_state_dict diff --git a/src/maxdiffusion/schedulers/__init__.py b/src/maxdiffusion/schedulers/__init__.py index b630948e0..edd249de1 100644 --- a/src/maxdiffusion/schedulers/__init__.py +++ b/src/maxdiffusion/schedulers/__init__.py @@ -48,6 +48,7 @@ _import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"] _import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"] _import_structure["scheduling_sde_ve_flax"] = ["FlaxScoreSdeVeScheduler"] + _import_structure["scheduling_flow_match_euler_discrete_flax"] = ["FlowMatchEulerDiscreteScheduler"] _import_structure["scheduling_utils_flax"] = [ "FlaxKarrasDiffusionSchedulers", "FlaxSchedulerMixin", @@ -73,6 +74,7 @@ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler + from .scheduling_flow_match_euler_discrete_flax import FlowMatchEulerDiscreteScheduler from .scheduling_utils_flax import ( FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py new file mode 100644 index 000000000..7d750c8bb --- /dev/null +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -0,0 +1,460 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you 
may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import jax +import jax.numpy as jnp +from flax import nnx +import numpy as np +import unittest +from absl.testing import absltest +from skimage.metrics import structural_similarity as ssim +from ..models.wan.autoencoder_kl_wan import ( + WanCausalConv3d, + WanUpsample, + AutoencoderKLWan, + WanMidBlock, + WanResidualBlock, + WanRMS_norm, + WanResample, + ZeroPaddedConv2D, + WanAttentionBlock, + AutoencoderKLWanCache, +) +from ..models.wan.wan_utils import load_wan_vae +from ..utils import load_video +from ..video_processor import VideoProcessor + +CACHE_T = 2 + + +class TorchWanRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class TorchWanResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. 
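+
+    In this test-only reference port, the 'upsample3d' and 'downsample3d'
+    branches raise, as they are not exercised by these tests.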
+ """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": + raise Exception("downsample3d not supported") + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + raise Exception("downsample3d not supported") + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class WanVaeTest(unittest.TestCase): + + def setUp(self): + WanVaeTest.dummy_data = {} + + def test_wanrms_norm(self): + """Test against the Pytorch implementation""" + + # --- Test Case 1: images == True --- + dim = 96 + input_shape = (1, 96, 1, 480, 720) + + model = TorchWanRMS_norm(dim) + input = torch.ones(input_shape) + torch_output = model(input) + torch_output_np = torch_output.detach().numpy() + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs) + dummy_input = jnp.ones(input_shape) + output = wanrms_norm(dummy_input) + output_np = np.array(output) + assert np.allclose(output_np, torch_output_np) is True + + # --- Test Case 2: images == False --- + model = TorchWanRMS_norm(dim, images=False) + input = torch.ones(input_shape) + torch_output = model(input) + torch_output_np = torch_output.detach().numpy() + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs, images=False) + dummy_input = jnp.ones(input_shape) + output = wanrms_norm(dummy_input) + output_np = np.array(output) + assert np.allclose(output_np, torch_output_np) is True + + def test_zero_padded_conv(self): + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + + dim = 96 + kernel_size = 3 + stride = (2, 2) + resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, kernel_size, stride=stride)) + input_shape = (1, 96, 480, 720) + input = 
torch.ones(input_shape) + output_torch = resample(input) + assert output_torch.shape == (1, 96, 240, 360) + + model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) + dummy_input = jnp.ones(input_shape) + dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1)) + output = model(dummy_input) + output = jnp.transpose(output, (0, 3, 1, 2)) + assert output.shape == (1, 96, 240, 360) + + def test_wan_upsample(self): + batch_size = 1 + in_depth, in_height, in_width = 10, 32, 32 + in_channels = 3 + + dummy_input = jnp.ones((batch_size * in_depth, in_height, in_width, in_channels)) + + upsample = WanUpsample(scale_factor=(2.0, 2.0)) + + # --- Test Case 1: depth > 1 --- + output = upsample(dummy_input) + assert output.shape == (10, 64, 64, 3) + + def test_wan_resample(self): + # TODO - needs to test all modes - upsample2d, upsample3d, downsample2d, downsample3d and identity + key = jax.random.key(0) + rngs = nnx.Rngs(key) + + # --- Test Case 1: downsample2d --- + batch = 1 + dim = 96 + t = 1 + h = 480 + w = 720 + mode = "downsample2d" + input_shape = (batch, dim, t, h, w) + dummy_input = torch.ones(input_shape) + torch_wan_resample = TorchWanResample(dim=dim, mode=mode) + torch_output = torch_wan_resample(dummy_input) + assert torch_output.shape == (batch, dim, t, h // 2, w // 2) + + wan_resample = WanResample(dim, mode=mode, rngs=rngs) + # channels is always last here + input_shape = (batch, t, h, w, dim) + dummy_input = jnp.ones(input_shape) + output = wan_resample(dummy_input) + assert output.shape == (batch, t, h // 2, w // 2, dim) + + def test_3d_conv(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + batch_size = 1 + in_depth, in_height, in_width = 10, 32, 32 + in_channels = 3 + out_channels = 16 + kernel_d, kernel_h, kernel_w = 3, 3, 3 # Kernel size (Depth, Height, Width) + padding_d, padding_h, padding_w = 1, 1, 1 # Base padding (Depth, Height, Width) + + # Create dummy input data + dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels)) + + # Create dummy cache data (from a previous step) + cache_depth = 2 * padding_d + dummy_cache = jnp.zeros((batch_size, cache_depth, in_height, in_width, in_channels)) + + # Instantiate the module + causal_conv_layer = WanCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(kernel_d, kernel_h, kernel_w), + padding=(padding_d, padding_h, padding_w), + rngs=rngs, # Pass rngs for initialization + ) + + # --- Test Case 1: No Cache --- + output_no_cache = causal_conv_layer(dummy_input) + assert output_no_cache.shape == (1, 10, 32, 32, 16) + + # --- Test Case 2: With Cache --- + output_with_cache = causal_conv_layer(dummy_input, cache_x=dummy_cache) + assert output_with_cache.shape == (1, 10, 32, 32, 16) + + # --- Test Case 3: With Cache larger than padding --- + larger_cache_depth = 4 # Larger than needed padding (2*padding_d = 2) + dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels)) + output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache) + assert output_with_larger_cache.shape == (1, 10, 32, 32, 16) + + def test_wan_residual(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + # --- Test Case 1: same in/out dim --- + in_dim = out_dim = 96 + batch = 1 + t = 1 + height = 480 + width = 720 + dim = 96 + input_shape = (batch, t, height, width, dim) + expected_output_shape = (batch, t, height, width, dim) + + wan_residual_block = WanResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + 
rngs=rngs, + ) + dummy_input = jnp.ones(input_shape) + dummy_output = wan_residual_block(dummy_input) + assert dummy_output.shape == expected_output_shape + + # --- Test Case 2: different in/out dim --- + in_dim = 96 + out_dim = 196 + expected_output_shape = (batch, t, height, width, out_dim) + + wan_residual_block = WanResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + rngs=rngs, + ) + dummy_input = jnp.ones(input_shape) + dummy_output = wan_residual_block(dummy_input) + assert dummy_output.shape == expected_output_shape + + def test_wan_attention(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + dim = 384 + batch = 1 + t = 1 + height = 60 + width = 90 + input_shape = (batch, t, height, width, dim) + wan_attention = WanAttentionBlock(dim=dim, rngs=rngs) + dummy_input = jnp.ones(input_shape) + output = wan_attention(dummy_input) + assert output.shape == input_shape + + def test_wan_midblock(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + batch = 1 + t = 1 + dim = 384 + height = 60 + width = 90 + input_shape = (batch, t, height, width, dim) + wan_midblock = WanMidBlock(dim=dim, rngs=rngs) + dummy_input = jnp.ones(input_shape) + output = wan_midblock(dummy_input) + assert output.shape == input_shape + + def test_wan_decode(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + dim = 96 + z_dim = 16 + dim_mult = [1, 2, 4, 4] + num_res_blocks = 2 + attn_scales = [] + temperal_downsample = [False, True, True] + wan_vae = AutoencoderKLWan( + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + ) + vae_cache = AutoencoderKLWanCache(wan_vae) + batch = 1 + t = 13 + channels = 16 + height = 60 + width = 90 + input_shape = (batch, t, height, width, channels) + input = jnp.ones(input_shape) + + latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim) + latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim) + input = input / latents_std + latents_mean + dummy_output = wan_vae.decode(input, feat_cache=vae_cache) + assert dummy_output.sample.shape == (batch, 49, 480, 720, 3) + + def test_wan_encode(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + dim = 96 + z_dim = 16 + dim_mult = [1, 2, 4, 4] + num_res_blocks = 2 + attn_scales = [] + temperal_downsample = [False, True, True] + wan_vae = AutoencoderKLWan( + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + ) + vae_cache = AutoencoderKLWanCache(wan_vae) + batch = 1 + channels = 3 + t = 49 + height = 480 + width = 720 + input_shape = (batch, channels, t, height, width) + input = jnp.ones(input_shape) + output = wan_vae.encode(input, feat_cache=vae_cache) + assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16) + + def test_load_checkpoint(self): + def vae_encode(video, wan_vae, vae_cache, key): + latent = wan_vae.encode(video, feat_cache=vae_cache) + latent = latent.latent_dist.sample(key) + return latent + + pretrained_model_name_or_path = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + key = jax.random.key(0) + rngs = nnx.Rngs(key) + wan_vae = AutoencoderKLWan.from_config(pretrained_model_name_or_path, subfolder="vae", rngs=rngs) + vae_cache = AutoencoderKLWanCache(wan_vae) + video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" + video = load_video(video_path) + + 
vae_scale_factor_spatial = 2 ** len(wan_vae.temperal_downsample) + video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial) + width, height = video[0].size + video = video_processor.preprocess_video(video, height=height, width=width) # .to(dtype=jnp.float32) + original_video = jnp.array(np.array(video), dtype=jnp.bfloat16) + + graphdef, state = nnx.split(wan_vae) + params = state.to_pure_dict() + # This replaces the randomly initialized params with the pretrained checkpoint weights. + params = load_wan_vae(pretrained_model_name_or_path, params, "cpu") + params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) + wan_vae = nnx.merge(graphdef, params) + + p_vae_encode = jax.jit(functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key)) + original_video_shape = original_video.shape + latent = p_vae_encode(original_video) + + jitted_decode = jax.jit(functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False)) + video = jitted_decode(latent)[0] + video = jnp.transpose(video, (0, 4, 1, 2, 3)) + assert video.shape == original_video_shape + + original_video = torch.from_numpy(np.array(original_video.astype(jnp.float32))).to(dtype=torch.bfloat16) + video = torch.from_numpy(np.array(video)).to(dtype=torch.bfloat16) + video = video_processor.postprocess_video(video, output_type="np") + original_video = video_processor.postprocess_video(original_video, output_type="np") + ssim_compare = ssim(video[0], original_video[0], multichannel=True, channel_axis=-1, data_range=255) + assert ssim_compare >= 0.9999 + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/utils/__init__.py b/src/maxdiffusion/utils/__init__.py index 199b7c156..256934b3b 100644 --- a/src/maxdiffusion/utils/__init__.py +++ b/src/maxdiffusion/utils/__init__.py @@ -83,7 +83,7 @@ is_xformers_available, requires_backends, ) -from .loading_utils import load_image +from .loading_utils import load_image, load_video from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( @@ -103,7 +103,6 @@ convert_unet_state_dict_to_peft, ) - logger = get_logger(__name__) diff --git a/src/maxdiffusion/utils/export_utils.py b/src/maxdiffusion/utils/export_utils.py index cd31a7baa..5dfa3562f 100644 --- a/src/maxdiffusion/utils/export_utils.py +++ b/src/maxdiffusion/utils/export_utils.py @@ -3,14 +3,14 @@ import struct import tempfile from contextlib import contextmanager -from typing import List +from typing import List, Optional, Union import numpy as np import PIL.Image import PIL.ImageOps -from .import_utils import BACKENDS_MAPPING, is_opencv_available +from .import_utils import BACKENDS_MAPPING, is_imageio_available, is_opencv_available from .logging import get_logger @@ -111,7 +111,9 @@ def export_to_obj(mesh, output_obj_path: str = None): f.writelines("\n".join(combined_data)) -def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str: +def _legacy_export_to_video( + video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10 +): if is_opencv_available(): import cv2 else: @@ -119,10 +121,88 @@ def export_to_video(video_frames: List[np.ndarray], output_video_path: str = Non if output_video_path is None: output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + if isinstance(video_frames[0], np.ndarray): + video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] + + elif isinstance(video_frames[0], PIL.Image.Image): + video_frames = [np.array(frame) for frame in 
video_frames] + + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + h, w, c = video_frames[0].shape - video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h)) + video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h)) for i in range(len(video_frames)): img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) video_writer.write(img) + + return output_video_path + + +def export_to_video( + video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], + output_video_path: str = None, + fps: int = 10, + quality: float = 5.0, + bitrate: Optional[int] = None, + macro_block_size: Optional[int] = 16, +) -> str: + """ + quality: + Video output quality. Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to + prevent variable bitrate flags being passed to FFMPEG so you can manually specify them using output_params instead. + Specifying a fixed bitrate using `bitrate` disables this parameter. + + bitrate: + Set a constant bitrate for the video encoding. Default is None, causing the `quality` parameter to be used instead. + Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter + rather than specifying a fixed bitrate with this parameter. + + macro_block_size: + Size constraint for video. Width and height must be divisible by this number. If not divisible by this number + imageio will tell ffmpeg to scale the image up to the next closest size divisible by this number. Most codecs + are compatible with a macroblock size of 16 (default), some can go smaller (4, 8). To disable this automatic + feature set it to None or 1, however be warned many players can't decode videos that are odd in size and some + codecs will produce poor results or fail. See https://en.wikipedia.org/wiki/Macroblock. + """ + # TODO: Dhruv. Remove by Diffusers release 0.33.0 + # Added to prevent breaking existing code + if not is_imageio_available(): + logger.warning( + ( + "It is recommended to use `export_to_video` with `imageio` and `imageio-ffmpeg` as a backend. \n" + "These libraries are not present in your environment. Attempting to use legacy OpenCV backend to export video. \n" + "Support for the OpenCV backend will be deprecated in a future Diffusers version" + ) + ) + return _legacy_export_to_video(video_frames, output_video_path, fps) + + if is_imageio_available(): + import imageio + else: + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("export_to_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + ( + "Found an existing imageio backend in your environment. Attempting to export video with imageio. \n" + "Unable to find a compatible ffmpeg installation in your environment to use with imageio. 
Please install via `pip install imageio-ffmpeg`" + ) + ) + + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + if isinstance(video_frames[0], np.ndarray): + video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] + + elif isinstance(video_frames[0], PIL.Image.Image): + video_frames = [np.array(frame) for frame in video_frames] + + with imageio.get_writer( + output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size + ) as writer: + for frame in video_frames: + writer.append_data(frame) + return output_video_path diff --git a/src/maxdiffusion/utils/import_utils.py b/src/maxdiffusion/utils/import_utils.py index 99d88336a..d83596e8d 100644 --- a/src/maxdiffusion/utils/import_utils.py +++ b/src/maxdiffusion/utils/import_utils.py @@ -51,6 +51,21 @@ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + +def _is_package_available(pkg_name: str): + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" + + if pkg_exists: + try: + pkg_version = importlib_metadata.version(pkg_name) + logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False + + return pkg_exists, pkg_version + + _torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: _torch_available = importlib.util.find_spec("torch") is not None @@ -105,6 +120,7 @@ except importlib_metadata.PackageNotFoundError: _transformers_available = False +_imageio_available, _imageio_version = _is_package_available("imageio") _inflect_available = importlib.util.find_spec("inflect") is not None try: @@ -285,6 +301,10 @@ _peft_available = False +def is_imageio_available(): + return _imageio_available + + def is_torch_available(): return _torch_available @@ -486,6 +506,11 @@ def is_peft_available(): {0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0` """ +# docstyle-ignore +IMAGEIO_IMPORT_ERROR = """ +{0} requires the imageio library and ffmpeg but they were not found in your environment. 
You can install them with pip: `pip install imageio imageio-ffmpeg` +""" + + BACKENDS_MAPPING = OrderedDict( [ @@ -506,6 +531,7 @@ def is_peft_available(): ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), + ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), ] ) diff --git a/src/maxdiffusion/utils/loading_utils.py b/src/maxdiffusion/utils/loading_utils.py index 07c08e726..6107272c7 100644 --- a/src/maxdiffusion/utils/loading_utils.py +++ b/src/maxdiffusion/utils/loading_utils.py @@ -1,9 +1,12 @@ import os -from typing import Union +from typing import Callable, List, Optional, Union import PIL.Image import PIL.ImageOps import requests +import tempfile +from urllib.parse import unquote, urlparse +from .import_utils import BACKENDS_MAPPING, is_imageio_available def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: @@ -33,3 +36,87 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: image = PIL.ImageOps.exif_transpose(image) image = image.convert("RGB") return image + + +def load_video( + video: str, + convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, +) -> List[PIL.Image.Image]: + """ + Loads `video` into a list of PIL images. + + Args: + video (`str`): + A URL or path to a video to convert into a list of PIL images. + convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): + A conversion method to apply to the video after loading it. When set to `None`, the images will be converted + to "RGB". + + Returns: + `List[PIL.Image.Image]`: + The video as a list of PIL images. + """ + is_url = video.startswith("http://") or video.startswith("https://") + is_file = os.path.isfile(video) + was_tempfile_created = False + + if not (is_url or is_file): + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." + ) + + if is_url: + response = requests.get(video, stream=True) + if response.status_code != 200: + raise ValueError(f"Failed to download video. Status code: {response.status_code}") + + parsed_url = urlparse(video) + file_name = os.path.basename(unquote(parsed_url.path)) + + suffix = os.path.splitext(file_name)[1] or ".mp4" + video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name + + was_tempfile_created = True + + video_data = response.iter_content(chunk_size=8192) + with open(video_path, "wb") as f: + for chunk in video_data: + f.write(chunk) + + video = video_path + + pil_images = [] + if video.endswith(".gif"): + gif = PIL.Image.open(video) + try: + while True: + pil_images.append(gif.copy()) + gif.seek(gif.tell() + 1) + except EOFError: + pass + + else: + if is_imageio_available(): + import imageio + else: + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + "Unable to find an ffmpeg installation on your machine. 
Please install via `pip install imageio-ffmpeg`" + ) + + with imageio.get_reader(video) as reader: + # Read all frames + for frame in reader: + pil_images.append(PIL.Image.fromarray(frame)) + + if was_tempfile_created: + os.remove(video_path) + + if convert_method is not None: + pil_images = convert_method(pil_images) + + return pil_images diff --git a/src/maxdiffusion/video_processor.py b/src/maxdiffusion/video_processor.py new file mode 100644 index 000000000..c29485118 --- /dev/null +++ b/src/maxdiffusion/video_processor.py @@ -0,0 +1,113 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch + +from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist + + +class VideoProcessor(VaeImageProcessor): + r"""Simple video processor.""" + + def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor: + r""" + Preprocesses input video(s). + + Args: + video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`): + The input video. It can be one of the following: + * List of PIL images. + * List of list of PIL images. + * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`). + * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, + width)`). + * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width, + num_channels)`. + * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height, + width)`. + height (`int`, *optional*, defaults to `None`): + The height in preprocessed frames of the video. If `None`, will use `get_default_height_width()` to + get the default height. + width (`int`, *optional*, defaults to `None`): + The width in preprocessed frames of the video. If `None`, will use `get_default_height_width()` to get + the default width. + """ + if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d np.ndarray is deprecated. " + "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", + FutureWarning, + ) + video = np.concatenate(video, axis=0) + if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d torch.Tensor is deprecated. " 
+ "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor", + FutureWarning, + ) + video = torch.cat(video, axis=0) + + # ensure the input is a list of videos: + # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) + # - if it is a single video, it is convereted to a list of one video. + if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: + video = list(video) + elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): + video = [video] + elif isinstance(video, list) and is_valid_image_imagelist(video[0]): + video = video + else: + raise ValueError( + "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" + ) + + video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + + # move the number of channels before the number of frames. + video = video.permute(0, 2, 1, 3, 4) + + return video + + def postprocess_video( + self, video: torch.Tensor, output_type: str = "np" + ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: + r""" + Converts a video tensor to a list of frames for export. + + Args: + video (`torch.Tensor`): The video as a tensor. + output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. + """ + batch_size = video.shape[0] + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = self.postprocess(batch_vid, output_type) + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + elif output_type == "pt": + outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") + + return outputs