From 15d242eb34e54613da7d5e4ee4dfd04eeb190a3e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Sun, 9 Mar 2025 20:49:12 +0000 Subject: [PATCH 01/25] wip - wan transformer --- src/maxdiffusion/models/attention_flax.py | 144 ++++++++- src/maxdiffusion/models/embeddings_flax.py | 4 +- src/maxdiffusion/models/wan/__init__.py | 15 + .../wan/transformers/transformer_flux_wan.py | 305 ++++++++++++++++++ 4 files changed, 456 insertions(+), 12 deletions(-) create mode 100644 src/maxdiffusion/models/wan/__init__.py create mode 100644 src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index db8626984..8008852ea 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -383,6 +383,139 @@ def chunk_scanner(chunk_idx, _): return jnp.concatenate(res, axis=-3) # fuse the chunked result back +def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) + + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + + return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) + +class FlaxWanAttention(nn.module): + query_dim: int + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + attention_kernel: str = "dot_product" + flash_min_seq_length: int = 4096 + flash_block_sizes: BlockSizes = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + query_axis_names: AxisNames = (BATCH, LENGTH, HEAD) + key_axis_names: AxisNames = (BATCH, LENGTH, HEAD) + value_axis_names: AxisNames = (BATCH, LENGTH, HEAD) + out_axis_names: AxisNames = (BATCH, LENGTH, EMBED) + precision: jax.lax.Precision = None + qkv_bias: bool = False + + def setup(self): + if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: + raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") + inner_dim = self.dim_head * self.heads + scale = self.dim_head**-0.5 + + self.attention_op = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + scale=scale, + heads=self.heads, + dim_head=self.dim_head, + flash_min_seq_length=self.flash_min_seq_length, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + flash_block_sizes=self.flash_block_sizes, + dtype=self.dtype, + float32_qk_product=False, + ) + + kernel_axes = ("embed", "heads") + qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes) + + qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "heads")) + + self.query = nn.Dense( + inner_dim, + kernel_init=qkv_init_kernel, + use_bias=False, + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="to_q", + precision=self.precision, + ) + + self.key = nn.Dense( + inner_dim, + kernel_init=qkv_init_kernel, + use_bias=False, + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="to_k", + precision=self.precision, + ) + + self.value = nn.Dense( + inner_dim, + kernel_init=qkv_init_kernel, + use_bias=False, + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="to_v", + precision=self.precision, + ) + + 
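+    # Note: the RMSNorm layers below apply query/key normalization (qk-norm)
+    # to the projected q/k before attention, matching the
+    # qk_norm="rms_norm_across_heads" default used by the Wan transformer
+    # blocks added later in this patch series.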
+    self.query_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+    self.key_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+
+    self.proj_attn = nn.Dense(
+        self.query_dim,
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("heads", "embed")),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_out_0",
+        precision=self.precision,
+    )
+    self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+  def __call__(
+      self,
+      hidden_states: Array,
+      encoder_hidden_states: Optional[Array],
+      rotary_emb: Optional[Array],
+      deterministic: bool = True
+  ):
+    encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
+
+    query_proj = self.query(hidden_states)
+    key_proj = self.key(encoder_hidden_states)
+    value_proj = self.value(encoder_hidden_states)
+
+    query_proj = self.query_norm(query_proj)
+    key_proj = self.key_norm(key_proj)
+
+    if rotary_emb is not None:
+      query_proj, key_proj = apply_rope(query_proj, key_proj, rotary_emb)
+
+    query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
+    key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
+    value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
+
+    hidden_states = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+
+    hidden_states = self.proj_attn(hidden_states)
+    hidden_states = nn.with_logical_constraint(hidden_states, (BATCH, LENGTH, HEAD))
+    return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 class FlaxFluxAttention(nn.Module):
   query_dim: int
@@ -493,15 +626,6 @@ def setup(self):
       param_dtype=self.weights_dtype,
     )
 
-  def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
-    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
-
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-
-    return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
-
   def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
 
     qkv_proj = self.qkv(hidden_states)
@@ -535,7 +659,7 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=Non
     value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
 
     image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)
-    query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb)
+    query_proj, key_proj = apply_rope(query_proj, key_proj, image_rotary_emb)
 
     query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1)
     key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1)
diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py
index 42ca4b950..3377ef6cc 100644
--- a/src/maxdiffusion/models/embeddings_flax.py
+++ b/src/maxdiffusion/models/embeddings_flax.py
@@ -56,7 +56,6 @@ def get_sinusoidal_embeddings(
     signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
     return signal
 
-
 class FlaxTimestepEmbedding(nn.Module):
   r"""
   Time step Embedding Module. Learns embeddings for input time steps.
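
For orientation, `get_sinusoidal_embeddings` above follows the standard transformer-style
sinusoidal recipe: half the feature channels are sines, half are cosines, over a geometric
ladder of frequencies. A minimal self-contained sketch (illustrative only; not the exact
maxdiffusion signature):

```python
import jax.numpy as jnp

def sinusoidal_embedding(timesteps, dim, max_period=10000.0):
    # Geometric frequency ladder from 1 down to 1/max_period.
    half = dim // 2
    freqs = jnp.exp(-jnp.log(max_period) * jnp.arange(half) / half)
    args = timesteps[:, None].astype(jnp.float32) * freqs[None, :]
    # Concatenate sin and cos halves -> shape (batch, dim).
    return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)

print(sinusoidal_embedding(jnp.array([0, 500, 999]), 8).shape)  # (3, 8)
```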
@@ -91,7 +90,8 @@ class FlaxTimesteps(nn.Module): dim: int = 32 flip_sin_to_cos: bool = False - freq_shift: float = 1 + freq_shift: float = 1.0 + scale: int = 1 @nn.compact def __call__(self, timesteps): diff --git a/src/maxdiffusion/models/wan/__init__.py b/src/maxdiffusion/models/wan/__init__.py new file mode 100644 index 000000000..7e4185f36 --- /dev/null +++ b/src/maxdiffusion/models/wan/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py new file mode 100644 index 000000000..968798c5d --- /dev/null +++ b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py @@ -0,0 +1,305 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+"""
+
+from typing import Tuple, Dict, Optional, Any, Union
+import jax
+import math
+import jax.numpy as jnp
+from chex import Array
+import flax
+import flax.linen as nn
+
+from ...attention_flax import FlaxFeedForward, Fla
+from ...embeddings_flax import (
+    get_1d_rotary_pos_embed,
+    FlaxTimesteps,
+    FlaxTimestepEmbedding,
+    PixArtAlphaTextProjection
+)
+
+from ....configuration_utils import ConfigMixin, flax_register_to_config
+from ...modeling_flax_utils import FlaxModelMixin
+from .... import common_types
+
+BlockSizes = common_types.BlockSizes
+
+class WanRotaryPosEmbed(nn.Module):
+    attention_head_dim: int
+    patch_size: Tuple[int, int, int]
+    max_seq_len: int
+    theta: float = 10000.0
+
+    @nn.compact
+    def __call__(self, hidden_states: Array) -> Array:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t, p_h, p_w = self.patch_size
+        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+        h_dim = w_dim = 2 * (self.attention_head_dim // 6)
+        t_dim = self.attention_head_dim - h_dim - w_dim
+
+        freqs = []
+        for dim in [t_dim, h_dim, w_dim]:
+            freq = get_1d_rotary_pos_embed(dim, self.max_seq_len, self.theta, freqs_dtype=jnp.float64)
+            freqs.append(freq)
+        freqs = jnp.concatenate(freqs, axis=1)
+
+        sizes = [
+            self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
+            self.attention_head_dim // 6,
+            self.attention_head_dim // 6
+        ]
+        cumulative_sizes = jnp.cumsum(jnp.array(sizes))
+        split_indices = cumulative_sizes[:-1]
+        freqs_split = jnp.split(freqs, split_indices, axis=1)
+
+        freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1)
+        freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1]))
+
+        freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2)
+        freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1]))
+
+        freqs_w = jnp.expand_dims(jnp.expand_dims(freqs_split[2][:ppw], axis=0), axis=1)
+        freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1]))
+
+        freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1)
+        freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1))
+
+        return freqs_final
+
+class WanImageEmbeddings(nn.Module):
+    out_features: int
+    dtype: jnp.dtype = jnp.float32
+    weights_dtype: jnp.dtype = jnp.float32
+    precision: jax.lax.Precision = None
+
+    @nn.compact
+    def __call__(self, encoder_hidden_states_image: Array) -> Array:
+        hidden_states = nn.LayerNorm(
+            dtype=jnp.float32,
+            param_dtype=jnp.float32,
+        )(encoder_hidden_states_image)
+        hidden_states = FlaxFeedForward(
+            self.out_features,
+            dtype=self.dtype,
+            weights_dtype=self.weights_dtype,
+            precision=self.precision
+        )(hidden_states)
+        hidden_states = nn.LayerNorm(
+            dtype=jnp.float32,
+            param_dtype=jnp.float32,
+        )(hidden_states)
+        return hidden_states
+
+
+class WanTimeTextImageEmbeddings(nn.Module):
+    dim: int
+    time_freq_dim: int
+    time_proj_dim: int
+    text_embed_dim: int
+    image_embed_dim: Optional[int] = None
+    dtype: jnp.dtype = jnp.float32
+    weights_dtype: jnp.dtype = jnp.float32
+    precision: jax.lax.Precision = None
+
+    @nn.compact
+    def __call__(self, timestep: Array, encoder_hidden_states: Array, encoder_hidden_states_image: Array) -> Array:
+
+        timestep = FlaxTimesteps(
+            dim=self.time_freq_dim,
+            flip_sin_to_cos=True,
+            freq_shift=0,
+        )(timestep)
+        temb = FlaxTimestepEmbedding(
+            time_embed_dim=self.dim,
+            dtype=self.dtype,
+            weights_dtype=self.weights_dtype
+        )(timestep)
+        timestep_proj = nn.Dense(
+            self.time_proj_dim,
+            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")),
+            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+            dtype=self.dtype,
+            param_dtype=self.weights_dtype,
+            precision=self.precision,
+        )(nn.silu(temb))
+        encoder_hidden_states = PixArtAlphaTextProjection(
+            hidden_size=self.dim,
+            act_fn="gelu_tanh",
+            dtype=self.dtype,
+            weights_dtype=self.weights_dtype,
+            precision=self.precision
+        )(encoder_hidden_states)
+
+        if encoder_hidden_states_image is not None:
+            encoder_hidden_states_image = WanImageEmbeddings(
+                out_features=self.dim,
+                dtype=self.dtype,
+                weights_dtype=self.weights_dtype,
+                precision=self.precision
+            )(encoder_hidden_states_image)
+
+        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+class WanTransformerBlock(nn.Module):
+    dim: int
+    ffn_dim: int
+    num_heads: int
+    qk_norm: str = "rms_norm_across_heads"
+    cross_attn_norm: bool = False
+    eps: float = 1e-6
+    added_kv_proj_dim: Optional[int] = None
+
+    @nn.compact
+    def __call__(
+        self,
+        hidden_states: Array,
+        encoder_hidden_states: Array,
+        temb: Array,
+        rotary_emb: Array
+    ):
+
+        # NOTE (wip): scale_shift_table is not defined here yet; in the
+        # reference implementation it is a learned (1, 6, dim) modulation table.
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+            (scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
+        )
+
+        # 1. Self-attention
+        norm_hidden_states = (nn.LayerNorm(
+            epsilon=self.eps,
+            use_bias=False,
+            use_scale=False,
+            dtype=jnp.float32,
+            param_dtype=jnp.float32,
+        )(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
+        attn_output = FlaxWanAttention(
+            query_dim=self.dim,
+            heads=self.num_heads,
+            dim_head=self.dim // self.num_heads,
+            dtype=self.dtype,
+            weights_dtype=self.weights_dtype,
+            attention_kernel=self.attention_kernel,
+            mesh=self.mesh,
+            flash_block_sizes=self.flash_block_sizes,
+        )
+
+class WanTransformer3dModel(nn.Module, FlaxModelMixin, ConfigMixin):
+    r"""
+    A Transformer model for video-like data used in the Wan model.
+
+    Args:
+        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+        num_attention_heads (`int`, defaults to `40`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        text_dim (`int`, defaults to `4096`):
+            Input dimension for text embeddings.
+        freq_dim (`int`, defaults to `256`):
+            Dimension for sinusoidal time embeddings.
+        ffn_dim (`int`, defaults to `13824`):
+            Intermediate dimension in feed-forward network.
+        num_layers (`int`, defaults to `40`):
+            The number of layers of transformer blocks to use.
+        window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+            Window size for local attention (-1 indicates global attention).
+        cross_attn_norm (`bool`, defaults to `True`):
+            Enable cross-attention normalization.
+        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
+            Enable query/key normalization.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        add_img_emb (`bool`, defaults to `False`):
+            Whether to use img_emb.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
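+        rope_max_seq_len (`int`, defaults to `1024`):
+            Maximum sequence length used to precompute the rotary position embeddings.
+        attention (`str`, defaults to `"dot_product"`):
+            Attention kernel to use: `dot_product`, `flash`, or `cudnn_flash_te`.
+        flash_block_sizes (`BlockSizes`, *optional*):
+            Block sizes for the flash attention kernel; `mesh` must be set when a flash kernel is selected.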
+    """
+
+    patch_size: Tuple[int] = (1, 2, 2)
+    num_attention_heads: int = 40
+    attention_head_dim: int = 128
+    in_channels: int = 16
+    out_channels: int = 16
+    text_dim: int = 4096
+    freq_dim: int = 256
+    ffn_dim: int = 13824
+    num_layers: int = 40
+    cross_attn_norm: bool = True
+    qk_norm: Optional[str] = "rms_norm_across_heads"
+    eps: float = 1e-6
+    image_dim: Optional[int] = None
+    added_kv_proj_dim: Optional[int] = None
+    rope_max_seq_len: int = 1024
+    flash_min_seq_length: int = 4096
+    flash_block_sizes: BlockSizes = None
+    mesh: jax.sharding.Mesh = None
+    dtype: jnp.dtype = jnp.float32
+    weights_dtype: jnp.dtype = jnp.float32
+    precision: jax.lax.Precision = None
+    attention: str = "dot_product"
+
+    @nn.compact
+    def __call__(
+        self,
+        hidden_states: Array,
+        timestep: Array,
+        encoder_hidden_states: Array,
+        encoder_hidden_states_image: Optional[Array] = None,
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None
+    ) -> Union[Array, Dict[str, Array]]:
+
+        inner_dim = self.num_attention_heads * self.attention_head_dim
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+        p_t, p_h, p_w = self.patch_size
+        post_patch_num_frames = num_frames // p_t
+        post_patch_height = height // p_h
+        post_patch_width = width // p_w
+
+        # 1. Patch & position embedding
+        rotary_emb = WanRotaryPosEmbed(
+            attention_head_dim=self.attention_head_dim,
+            patch_size=self.patch_size,
+            max_seq_len=self.rope_max_seq_len
+        )(hidden_states)
+        hidden_states = nn.Conv(
+            features=inner_dim,
+            kernel_size=self.patch_size,
+            strides=self.patch_size,
+            dtype=self.dtype,
+            param_dtype=self.weights_dtype,
+            precision=self.precision,
+        )(hidden_states)
+        flattened_shape = (batch_size, num_channels, -1)  # TODO: is this num_channels or frames?
+        flattened = hidden_states.reshape(flattened_shape)
+        transposed = jnp.transpose(flattened, (0, 2, 1))
+
+        # 2. Condition embeddings
+        # image_embedding_dim=1280 for I2V model
+        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = WanTimeTextImageEmbeddings(
+            dim=inner_dim,
+            time_freq_dim=self.freq_dim,
+            time_proj_dim=inner_dim * 6,
+            text_embed_dim=self.text_dim,
+            image_embed_dim=self.image_dim,
+            dtype=self.dtype,
+            weights_dtype=self.weights_dtype,
+            precision=self.precision
+        )(timestep, encoder_hidden_states, encoder_hidden_states_image)
+

From 9b63238cffeced18a4b7d1836f4f093bff564d2d Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Wed, 16 Apr 2025 15:06:42 +0000
Subject: [PATCH 02/25] adding nnx - wip

---
 src/maxdiffusion/configs/base_wan_t2v.yml     | 271 ++++++++++++++++++
 src/maxdiffusion/generate_wan.py              |  33 +++
 src/maxdiffusion/models/attention_flax.py     |   4 +-
 .../wan/transformers/transformer_flux_wan.py  |   1 -
 .../transformers/transformer_flux_wan_nnx.py  |  70 +++++
 5 files changed, 376 insertions(+), 3 deletions(-)
 create mode 100644 src/maxdiffusion/configs/base_wan_t2v.yml
 create mode 100644 src/maxdiffusion/generate_wan.py
 create mode 100644 src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py

diff --git a/src/maxdiffusion/configs/base_wan_t2v.yml b/src/maxdiffusion/configs/base_wan_t2v.yml
new file mode 100644
index 000000000..944153d64
--- /dev/null
+++ b/src/maxdiffusion/configs/base_wan_t2v.yml
@@ -0,0 +1,271 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This sentinel is a reminder to choose a real run name.
+run_name: ''
+
+metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
+# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
+write_metrics: True
+gcs_metrics: False
+# If true save config to GCS in {base_output_directory}/{run_name}/
+save_config_to_gcs: False
+log_period: 100
+
+pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
+clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
+t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
+
+# Flux params
+flux_name: "flux-dev"
+max_sequence_length: 512
+time_shift: True
+base_shift: 0.5
+max_shift: 1.15
+# offloads t5 encoder after text encoding to save memory.
+offload_encoders: True
+
+
+unet_checkpoint: ''
+revision: 'refs/pr/95'
+# This will convert the weights to this dtype.
+# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
+weights_dtype: 'bfloat16'
+# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
+activations_dtype: 'bfloat16'
+
+# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
+# Options are "DEFAULT", "HIGH", "HIGHEST"
+# fp32 activations and fp32 weights with HIGHEST will provide the best precision
+# at the cost of time.
+precision: "DEFAULT"
+
+# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
+# Set true to load weights from pytorch
+from_pt: True
+split_head_dim: True
+attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+
+flash_block_sizes: {}
+# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
+# flash_block_sizes: {
+#   "block_q" : 1536,
+#   "block_kv_compute" : 1536,
+#   "block_kv" : 1536,
+#   "block_q_dkv" : 1536,
+#   "block_kv_dkv" : 1536,
+#   "block_kv_dkv_compute" : 1536,
+#   "block_q_dq" : 1536,
+#   "block_kv_dq" : 1536
+# }
+# GroupNorm groups
+norm_num_groups: 32
+
+# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
+# else they will be loaded from pretrained_model_name_or_path
+train_new_unet: False
+
+# train text_encoder - Currently not supported for SDXL
+train_text_encoder: False
+text_encoder_learning_rate: 4.25e-6
+
+# https://arxiv.org/pdf/2305.08891.pdf
+snr_gamma: -1.0
+
+timestep_bias: {
+  # a value of 'later' will increase the frequency of the model's final training steps.
+  # none, earlier, later, range
+  strategy: "none",
+  # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it.
+  multiplier: 1.0,
+  # when using strategy=range, the beginning (inclusive) timestep to bias.
+  begin: 0,
+  # when using strategy=range, the final step (inclusive) to bias.
+  end: 1000,
+  # portion of timesteps to bias.
+  # 0.5 will bias one half of the timesteps. Value of strategy determines
+  # whether the biased portions are in the earlier or later timesteps.
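+  # Example: strategy="later" with portion 0.25 biases the final quarter of a
+  # 1000-step schedule (timesteps 750-1000).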
+  portion: 0.25
+}
+
+# Override parameters from the checkpoint's scheduler.
+diffusion_scheduler_config: {
+  _class_name: 'FlaxEulerDiscreteScheduler',
+  prediction_type: 'epsilon',
+  rescale_zero_terminal_snr: False,
+  timestep_spacing: 'trailing'
+}
+
+# Output directory
+# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
+base_output_directory: ""
+
+# Hardware
+hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
+
+# Parallelism
+mesh_axes: ['data', 'fsdp', 'tensor']
+
+# batch : batch dimension of data and activations
+# hidden :
+# embed : attention qkv dense layer hidden dim named as embed
+# heads : attention head dim = num_heads * head_dim
+# length : attention sequence length
+# temb_in : dense.shape[0] of resnet dense before conv
+# out_c : dense.shape[1] of resnet dense before conv
+# out_channels : conv.shape[-1] activation
+# keep_1 : conv.shape[0] weight
+# keep_2 : conv.shape[1] weight
+# conv_in : conv.shape[2] weight
+# conv_out : conv.shape[-1] weight
+logical_axis_rules: [
+  ['batch', 'data'],
+  ['activation_batch', ['data','fsdp']],
+  ['activation_heads', 'tensor'],
+  ['activation_kv', 'tensor'],
+  ['mlp','tensor'],
+  ['embed','fsdp'],
+  ['heads', 'tensor'],
+  ['conv_batch', ['data','fsdp']],
+  ['out_channels', 'tensor'],
+  ['conv_out', 'fsdp'],
+]
+data_sharding: [['data', 'fsdp', 'tensor']]
+
+# One axis for each parallelism type may hold a placeholder (-1)
+# value to auto-shard based on available slices and devices.
+# By default, product of the DCN axes should equal number of slices
+# and product of the ICI axes should equal number of devices per slice.
+dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
+dcn_fsdp_parallelism: -1
+dcn_tensor_parallelism: 1
+ici_data_parallelism: -1
+ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_tensor_parallelism: 1
+
+# Dataset
+# Replace with dataset path or train_data_dir. One has to be set.
+dataset_name: 'diffusers/pokemon-gpt4-captions'
+train_split: 'train'
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
+# and only to small datasets that fit in memory:
+# prepare image latents and text encoder outputs
+# Reduce memory consumption and reduce step time during training
+# transformed dataset is saved at dataset_save_location
+dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
+train_data_dir: ''
+dataset_config_name: ''
+jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
+image_column: 'image'
+caption_column: 'text'
+resolution: 1024
+center_crop: False
+random_flip: False
+# If cache_latents_text_encoder_outputs is True
+# the num_proc is set to 1
+tokenize_captions_num_proc: 4
+transform_images_num_proc: 4
+reuse_example_batch: False
+enable_data_shuffling: True
+
+# checkpoint every number of samples, -1 means don't checkpoint.
+checkpoint_every: -1
+# enables one replica to read the ckpt then broadcast to the rest
+enable_single_replica_ckpt_restoring: False
+
+# Training loop
+learning_rate: 4.e-7
+scale_lr: False
+max_train_samples: -1
+# max_train_steps takes priority over num_train_epochs.
+max_train_steps: 200
+num_train_epochs: 1
+seed: 0
+output_dir: 'sdxl-model-finetuned'
+per_device_batch_size: 1
+
+warmup_steps_fraction: 0.0
+learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
+ +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 1.e-2 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Generation parameters +prompt: "A magical castle in the middle of a forest, artistic drawing" +prompt_2: "A magical castle in the middle of a forest, artistic drawing" +negative_prompt: "purple, red" +do_classifier_free_guidance: True +guidance_scale: 3.5 +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 50 + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. + diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py new file mode 100644 index 000000000..599fbc189 --- /dev/null +++ b/src/maxdiffusion/generate_wan.py @@ -0,0 +1,33 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+"""
+
+from typing import Callable, List, Union, Sequence
+from flax import nnx
+from absl import app
+from maxdiffusion import pyconfig, max_logging
+from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel
+
+def run(config):
+  max_logging.log("Wan 2.1 inference script")
+
+  wan_transformer = WanModel(rngs=nnx.Rngs(config.seed))
+
+def main(argv: Sequence[str]) -> None:
+  pyconfig.initialize(argv)
+  run(pyconfig.config)
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index b22ed60e9..fc838c9db 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -14,7 +14,7 @@
 
 import functools
 import math
-
+from typing import Optional
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -414,7 +414,7 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
 
   return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
 
-class FlaxWanAttention(nn.module):
+class FlaxWanAttention(nn.Module):
   query_dim: int
   heads: int = 8
   dim_head: int = 64
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py
index 968798c5d..cdcecf80b 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py
@@ -19,7 +19,6 @@
 import math
 import jax.numpy as jnp
 from chex import Array
-import flax
 import flax.linen as nn
 
 from ...attention_flax import FlaxFeedForward, Fla
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py
new file mode 100644
index 000000000..e5c1b7145
--- /dev/null
+++ b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py
@@ -0,0 +1,70 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from .... import common_types, max_logging
+from ...modeling_flax_utils import FlaxModelMixin
+from ....configuration_utils import ConfigMixin
+
+BlockSizes = common_types.BlockSizes
+
+class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
+  def __init__(
+      self,
+      rngs: nnx.Rngs,
+      model_type='t2v',
+      patch_size=(1, 2, 2),
+      text_len=512,
+      in_dim=16,
+      dim=2038,
+      ffn_dim=8192,
+      freq_dim=256,
+      text_dim=4096,
+      out_dim=16,
+      num_heads=16,
+      num_layers=32,
+      window_size=(-1, -1),
+      qk_norm=True,
+      cross_attn_norm=True,
+      eps=1e-6,
+      flash_min_seq_length: int = 4096,
+      flash_block_sizes: BlockSizes = None,
+      mesh: jax.sharding.Mesh = None,
+      dtype: jnp.dtype = jnp.float32,
+      weights_dtype: jnp.dtype = jnp.float32,
+      precision: jax.lax.Precision = None,
+      attention: str = "dot_product",
+  ):
+    # Patch embedding: a 3D conv whose kernel equals its stride, so each
+    # (p_t, p_h, p_w) video patch becomes one token of width `dim`.
+    self.patch_embedding = nnx.Conv(
+        in_dim,
+        dim,
+        kernel_size=patch_size,
+        strides=patch_size,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.xavier_uniform(),
+            ("batch",)
+        ),
+        rngs=rngs
+    )
+
+  def __call__(self, x):
+    x = self.patch_embedding(x)
+    return x
From 9276f26180fbc8e579b9dbf38f02d78a7c8273c8 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Fri, 18 Apr 2025 23:06:25 +0000
Subject: [PATCH 03/25] wan pipeline wip

---
 README.md                                     |   8 ++
 src/maxdiffusion/configs/base_wan_t2v.yml     |   4 +-
 src/maxdiffusion/generate_wan.py              |  11 +-
 src/maxdiffusion/image_processor.py           |  46 +++++
 .../models/wan/autoencoder_kl_wan.py          |  87 ++++++++++
 src/maxdiffusion/pipelines/wan/__init__.py    |   0
 .../pipelines/wan/pipeline_wan.py             |  84 +++++++++
 src/maxdiffusion/schedulers/__init__.py       |   2 +
 ...heduling_flow_match_euler_discrete_flax.py |  71 +++++++
 src/maxdiffusion/video_processor.py           | 113 ++++++++++++
 10 files changed, 422 insertions(+), 4 deletions(-)
 create mode 100644 src/maxdiffusion/models/wan/autoencoder_kl_wan.py
 create mode 100644 src/maxdiffusion/pipelines/wan/__init__.py
 create mode 100644 src/maxdiffusion/pipelines/wan/pipeline_wan.py
 create mode 100644 src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py
 create mode 100644 src/maxdiffusion/video_processor.py

diff --git a/README.md b/README.md
index 081d65e8e..2dc6523c4 100644
--- a/README.md
+++ b/README.md
@@ -53,6 +53,7 @@ MaxDiffusion supports
 - [Training](#training)
   - [Dreambooth](#dreambooth)
 - [Inference](#inference)
+  - [Wan 2.1](#wan)
   - [Flux](#flux)
     - [Fused Attention for GPU:](#fused-attention-for-gpu)
   - [Hyper SDXL LoRA](#hyper-sdxl-lora)
@@ -171,6 +172,13 @@ To generate images, run the following command:
 ```bash
 python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
 ```
+
+ ## Wan
+
+ ```bash
+ python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_t2v.yml run_name="wan-test" output_dir="gs://jfacevedo-maxdiffusion" jax_cache_dir="/tmp/"
+ ```
+
 ## Flux
 
 First make sure you have permissions to access the Flux repos in Huggingface.
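
As a sanity check on the `WanModel` patch embedding above, here is a minimal, hypothetical
shape walk-through (channels-last layout, as `flax.nnx.Conv` expects; the small dimensions
are made up for illustration):

```python
import jax.numpy as jnp
from flax import nnx

patch_size = (1, 2, 2)  # (t, h, w) patch dims, as in the model defaults above
conv = nnx.Conv(16, 64, kernel_size=patch_size, strides=patch_size, rngs=nnx.Rngs(0))

video = jnp.zeros((1, 8, 32, 32, 16))  # (batch, frames, height, width, channels)
tokens = conv(video)
print(tokens.shape)  # (1, 8, 16, 16, 64): one embedding per (1, 2, 2) patch
```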
diff --git a/src/maxdiffusion/configs/base_wan_t2v.yml b/src/maxdiffusion/configs/base_wan_t2v.yml index 944153d64..28ef6e77e 100644 --- a/src/maxdiffusion/configs/base_wan_t2v.yml +++ b/src/maxdiffusion/configs/base_wan_t2v.yml @@ -23,9 +23,7 @@ gcs_metrics: False save_config_to_gcs: False log_period: 100 -pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' -clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' -t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' +pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' # Flux params flux_name: "flux-dev" diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 599fbc189..27839954b 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -19,11 +19,20 @@ from absl import app from maxdiffusion import pyconfig, max_logging from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel +from maxdiffusion.pipelines.wan.pipeline_wan import WanPipeline def run(config): max_logging.log("Wan 2.1 inference script") - wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) + pipeline, params = WanPipeline.from_pretrained( + config.pretrained_model_name_or_path, + vae=None, + transformer=None + ) + breakpoint() + + #wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) diff --git a/src/maxdiffusion/image_processor.py b/src/maxdiffusion/image_processor.py index 788e1c94e..2f99d0f7f 100644 --- a/src/maxdiffusion/image_processor.py +++ b/src/maxdiffusion/image_processor.py @@ -35,6 +35,52 @@ List[torch.FloatTensor], ] +def is_valid_image(image) -> bool: + r""" + Checks if the input is a valid image. + + A valid image can be: + - A `PIL.Image.Image`. + - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. + + Returns: + `bool`: + `True` if the input is a valid image, `False` otherwise. + """ + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) + + +def is_valid_image_imagelist(images): + r""" + Checks if the input is a valid image or list of images. + + The input can be one of the following formats: + - A 4D tensor or numpy array (batch of images). + - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or + `torch.Tensor`. + - A list of valid images. + + Args: + images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): + The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid + images. + + Returns: + `bool`: + `True` if the input is valid, `False` otherwise. 
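+
+    Example (illustrative):
+
+    ```py
+    >>> import numpy as np
+    >>> is_valid_image_imagelist(np.zeros((4, 64, 64, 3)))  # 4D array = a batch of images
+    True
+    ```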
+ """ + if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: + return True + elif is_valid_image(images): + return True + elif isinstance(images, list): + return all(is_valid_image(image) for image in images) + return False + class VaeImageProcessor(ConfigMixin): """ diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py new file mode 100644 index 000000000..4dc878ce4 --- /dev/null +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -0,0 +1,87 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Tuple, List +from flax import nnx +from ...configuration_utils import ConfigMixin, flax_register_to_config +from ..modeling_flax_utils import FlaxModelMixin + +class WanEncoder3d(nnx.Module): + pass + +class WanCausalConv3d(nnx.Module): + pass + +class WanDecoder3d(nnx.Module): + pass + +class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1,2,4,4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temporal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + latents_std: List[float] = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ], + ): + self.z_dim = z_dim + self.temporal_downsample = temporal_downsample + self.temporal_upsample = temporal_downsample[::-1] + + self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) + + self.decoder = WanDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout + ) \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/wan/__init__.py b/src/maxdiffusion/pipelines/wan/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/pipelines/wan/pipeline_wan.py b/src/maxdiffusion/pipelines/wan/pipeline_wan.py new file mode 100644 index 000000000..df4cbb748 --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/pipeline_wan.py @@ -0,0 +1,84 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+"""
+
+from typing import Union, List
+from transformers import AutoTokenizer, UMT5EncoderModel
+import torch
+from ...models.wan.transformers.transformer_flux_wan_nnx import WanModel
+from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan
+from ..pipeline_flax_utils import FlaxDiffusionPipeline
+from ...video_processor import VideoProcessor
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+
+class WanPipeline(FlaxDiffusionPipeline):
+
+  def __init__(
+      self,
+      tokenizer: AutoTokenizer,
+      text_encoder: UMT5EncoderModel,
+      transformer: WanModel,
+      vae: AutoencoderKLWan,
+      scheduler: FlowMatchEulerDiscreteScheduler,
+  ):
+    super().__init__()
+
+    self.register_modules(
+        vae=vae,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        scheduler=scheduler
+    )
+
+    self.vae_scale_factor_temporal = 2 ** sum(self.vae.temporal_downsample) if getattr(self, "vae", None) else 4
+    self.vae_scale_factor_spatial = 2 ** len(self.vae.temporal_downsample) if getattr(self, "vae", None) else 8
+    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+  def _get_t5_prompt_embeds(
+      self,
+      prompt: Union[str, List[str]] = None,
+      num_videos_per_prompt: int = 1,
+      max_sequence_length: int = 226,
+  ):
+    prompt = [prompt] if isinstance(prompt, str) else prompt
+    batch_size = len(prompt)
+
+    text_inputs = self.tokenizer(
+        prompt,
+        padding="max_length",
+        max_length=max_sequence_length,
+        truncation=True,
+        add_special_tokens=True,
+        return_attention_mask=True,
+        return_tensors="pt",
+    )
+    text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+    seq_lens = mask.gt(0).sum(dim=1).long()
+
+    prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
+    # prompt_embeds = prompt_embeds.to(dtype=dtype)
+    prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+    prompt_embeds = torch.stack(
+        [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+    )
+
+    # duplicate text embeddings for each generation per prompt, using mps friendly method
+    _, seq_len, _ = prompt_embeds.shape
+    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+    return prompt_embeds
\ No newline at end of file
diff --git a/src/maxdiffusion/schedulers/__init__.py b/src/maxdiffusion/schedulers/__init__.py
index b630948e0..edd249de1 100644
--- a/src/maxdiffusion/schedulers/__init__.py
+++ b/src/maxdiffusion/schedulers/__init__.py
@@ -48,6 +48,7 @@
   _import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"]
   _import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"]
   _import_structure["scheduling_sde_ve_flax"] = ["FlaxScoreSdeVeScheduler"]
+  _import_structure["scheduling_flow_match_euler_discrete_flax"] = ["FlowMatchEulerDiscreteScheduler"]
   _import_structure["scheduling_utils_flax"] = [
       "FlaxKarrasDiffusionSchedulers",
       "FlaxSchedulerMixin",
@@ -73,6 +74,7 @@
   from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
   from .scheduling_pndm_flax import FlaxPNDMScheduler
   from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
+  from .scheduling_flow_match_euler_discrete_flax import FlowMatchEulerDiscreteScheduler
   from .scheduling_utils_flax import (
       FlaxKarrasDiffusionSchedulers,
       FlaxSchedulerMixin,
diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py
new file mode 100644 index 000000000..db96f53fd --- /dev/null +++ b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py @@ -0,0 +1,71 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + # FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) + +@flax.struct.dataclass +class FlowMatchEulerDiscreteSchedulerState: + common: CommonSchedulerState + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(FlaxSchedulerOutput): + state: FlowMatchEulerDiscreteSchedulerState + +class FlowMatchEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): + # _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + time_shift_type: str = "exponential", + dtype: jnp.dtype = jnp.float32 + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> FlowMatchEulerDiscreteSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) \ No newline at end of file diff --git a/src/maxdiffusion/video_processor.py b/src/maxdiffusion/video_processor.py new file mode 100644 index 000000000..2da782b46 --- /dev/null +++ b/src/maxdiffusion/video_processor.py @@ -0,0 +1,113 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import warnings
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+
+from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
+
+
+class VideoProcessor(VaeImageProcessor):
+    r"""Simple video processor."""
+
+    def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor:
+        r"""
+        Preprocesses input video(s).
+
+        Args:
+            video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`):
+                The input video. It can be one of the following:
+                * List of the PIL images.
+                * List of list of PIL images.
+                * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`).
+                * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`).
+                * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`).
+                * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`).
+                * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width, num_channels)`.
+                * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height, width)`.
+            height (`int`, *optional*, defaults to `None`):
+                The height in preprocessed frames of the video. If `None`, will use `get_default_height_width()` to
+                get the default height.
+            width (`int`, *optional*, defaults to `None`):
+                The width in preprocessed frames of the video. If `None`, will use `get_default_height_width()` to
+                get the default width.
+        """
+        if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
+            warnings.warn(
+                "Passing `video` as a list of 5d np.ndarray is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
+                FutureWarning,
+            )
+            video = np.concatenate(video, axis=0)
+        if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5:
+            warnings.warn(
+                "Passing `video` as a list of 5d torch.Tensor is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor",
+                FutureWarning,
+            )
+            video = torch.cat(video, axis=0)
+
+        # ensure the input is a list of videos:
+        # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray)
+        # - if it is a single video, it is converted to a list of one video.
+        if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
+            video = list(video)
+        elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
+            video = [video]
+        elif isinstance(video, list) and is_valid_image_imagelist(video[0]):
+            video = video
+        else:
+            raise ValueError(
+                "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
+            )
+
+        video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0)
+
+        # move the number of channels before the number of frames.
+        video = video.permute(0, 2, 1, 3, 4)
+
+        return video
+
+    def postprocess_video(
+        self, video: torch.Tensor, output_type: str = "np"
+    ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]:
+        r"""
+        Converts a video tensor to a list of frames for export.
+
+        Args:
+            video (`torch.Tensor`): The video as a tensor.
+ output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. + """ + batch_size = video.shape[0] + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = self.postprocess(batch_vid, output_type) + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + elif output_type == "pt": + outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") + + return outputs From 120ceb3b98e229cc0adc7963eca28abeaf33aea0 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 22 Apr 2025 18:26:04 +0000 Subject: [PATCH 04/25] wip - vae --- README.md | 8 + src/maxdiffusion/configs/base_wan_t2v.yml | 4 +- src/maxdiffusion/generate_wan.py | 191 ++++++++++++++++- src/maxdiffusion/image_processor.py | 46 +++++ .../models/wan/autoencoder_kl_wan.py | 193 ++++++++++++++++++ src/maxdiffusion/pipelines/wan/__init__.py | 0 .../pipelines/wan/pipeline_wan.py | 84 ++++++++ src/maxdiffusion/schedulers/__init__.py | 2 + ...heduling_flow_match_euler_discrete_flax.py | 71 +++++++ src/maxdiffusion/tests/wan_vae_test.py | 30 +++ src/maxdiffusion/video_processor.py | 113 ++++++++++ 11 files changed, 737 insertions(+), 5 deletions(-) create mode 100644 src/maxdiffusion/models/wan/autoencoder_kl_wan.py create mode 100644 src/maxdiffusion/pipelines/wan/__init__.py create mode 100644 src/maxdiffusion/pipelines/wan/pipeline_wan.py create mode 100644 src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py create mode 100644 src/maxdiffusion/tests/wan_vae_test.py create mode 100644 src/maxdiffusion/video_processor.py diff --git a/README.md b/README.md index 081d65e8e..2dc6523c4 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ MaxDiffusion supports - [Training](#training) - [Dreambooth](#dreambooth) - [Inference](#inference) + - [Wan 2.1](#wan) - [Flux](#flux) - [Fused Attention for GPU:](#fused-attention-for-gpu) - [Hyper SDXL LoRA](#hyper-sdxl-lora) @@ -171,6 +172,13 @@ To generate images, run the following command: ```bash python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run" ``` + + ## Wan + + ```bash + python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_t2v.yml run_name="wan-test" output_dir="gs://jfacevedo-maxdiffusion" jax_cache_dir="/tmp/" + ``` + ## Flux First make sure you have permissions to access the Flux repos in Huggingface. diff --git a/src/maxdiffusion/configs/base_wan_t2v.yml b/src/maxdiffusion/configs/base_wan_t2v.yml index 944153d64..28ef6e77e 100644 --- a/src/maxdiffusion/configs/base_wan_t2v.yml +++ b/src/maxdiffusion/configs/base_wan_t2v.yml @@ -23,9 +23,7 @@ gcs_metrics: False save_config_to_gcs: False log_period: 100 -pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' -clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' -t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' +pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' # Flux params flux_name: "flux-dev" diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 599fbc189..6945a62f5 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -13,18 +13,205 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - -from typing import Callable, List, Union, Sequence +import html +from typing import Callable, List, Union, Sequence, Optional +import time +import torch +import ftfy +import regex as re +import jax +from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P from flax import nnx from absl import app +from transformers import AutoTokenizer, UMT5EncoderModel from maxdiffusion import pyconfig, max_logging from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel +from maxdiffusion.pipelines.wan.pipeline_wan import WanPipeline + +from maxdiffusion.max_utils import ( + device_put_replicated, + get_memory_allocations, + create_device_mesh, + get_flash_block_sizes, + get_precision, + setup_initial_state, +) + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + +def _get_t5_prompt_embeds( + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + +def encode_prompt( + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. 
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = _get_t5_prompt_embeds(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ ) + + negative_prompt_embeds = _get_t5_prompt_embeds( + tokenizer=tokenizer, + text_encoder=text_encoder, + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds def run(config): max_logging.log("Wan 2.1 inference script") + rng = jax.random.key(config.seed) + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + global_batch_size = config.per_device_batch_size * jax.local_device_count() + + tokenizer = AutoTokenizer.from_pretrained( + config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype + ) + text_encoder = UMT5EncoderModel.from_pretrained( + config.pretrained_model_name_or_path, subfolder="text_encoder", + ) + s0 = time.perf_counter() + prompt_embeds, negative_prompt_embeds = encode_prompt( + tokenizer=tokenizer, + text_encoder=text_encoder, + prompt=config.prompt, + negative_prompt=config.negative_prompt + ) + max_logging.log(f"text encoding time: {(time.perf_counter() - s0)}") + + # pipeline, params = WanPipeline.from_pretrained( + # config.pretrained_model_name_or_path, + # #vae=None, + # #transformer=None + # ) + # breakpoint() + wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) run(pyconfig.config) diff --git a/src/maxdiffusion/image_processor.py b/src/maxdiffusion/image_processor.py index 788e1c94e..2f99d0f7f 100644 --- a/src/maxdiffusion/image_processor.py +++ b/src/maxdiffusion/image_processor.py @@ -35,6 +35,52 @@ List[torch.FloatTensor], ] +def is_valid_image(image) -> bool: + r""" + Checks if the input is a valid image. + + A valid image can be: + - A `PIL.Image.Image`. + - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. + + Returns: + `bool`: + `True` if the input is a valid image, `False` otherwise. + """ + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) + + +def is_valid_image_imagelist(images): + r""" + Checks if the input is a valid image or list of images. + + The input can be one of the following formats: + - A 4D tensor or numpy array (batch of images). + - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or + `torch.Tensor`. + - A list of valid images. + + Args: + images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): + The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid + images. + + Returns: + `bool`: + `True` if the input is valid, `False` otherwise. 
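+
+ Example (illustrative, mirroring the checks in the body below):
+
+ ```py
+ >>> import numpy as np
+ >>> is_valid_image_imagelist(np.zeros((4, 64, 64, 3)))  # 4D batch of images
+ True
+ >>> is_valid_image_imagelist([np.zeros((64, 64, 3))] * 2)  # list of 3D images
+ True
+ ```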
+ """
+ if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
+ return True
+ elif is_valid_image(images):
+ return True
+ elif isinstance(images, list):
+ return all(is_valid_image(image) for image in images)
+ return False
+

 class VaeImageProcessor(ConfigMixin):
 """
diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
new file mode 100644
index 000000000..65da97bc9
--- /dev/null
+++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -0,0 +1,193 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from typing import Tuple, List, Sequence, Union, Optional
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from ...configuration_utils import ConfigMixin, flax_register_to_config
+from ..modeling_flax_utils import FlaxModelMixin
+from ... import common_types
+
+BlockSizes = common_types.BlockSizes
+
+_ACTIVATIONS = {
+ "swish": jax.nn.silu,
+ "silu": jax.nn.silu,
+ "relu": jax.nn.relu,
+ "gelu": jax.nn.gelu,
+ "mish": jax.nn.mish
+}
+
+def get_activation(name: str):
+ func = _ACTIVATIONS.get(name)
+ if func is None:
+ raise ValueError(f"Unknown activation function: {name}")
+ return func
+
+# Helper to ensure kernel_size, stride, padding are tuples of 3 integers
+def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:
+ """Canonicalizes a value to a tuple of integers."""
+ if isinstance(x, int):
+ return (x,) * rank
+ elif isinstance(x, Sequence) and len(x) == rank:
+ return tuple(x)
+ else:
+ raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank} integers. Got {x}")
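+
+# e.g. (illustrative): _canonicalize_tuple(3, 3, 'kernel_size') -> (3, 3, 3),
+# while _canonicalize_tuple((3, 1, 1), 3, 'kernel_size') -> (3, 1, 1) unchanged.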
+
+class WanCausalConv3d(nnx.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]],
+ *, # Mark subsequent arguments as keyword-only
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ padding: Union[int, Tuple[int, int, int]] = 0,
+ use_bias: bool = True,
+ rngs: nnx.Rngs, # rngs are required for initializing parameters,
+ flash_min_seq_length: int = 4096,
+ flash_block_sizes: BlockSizes = None,
+ mesh: jax.sharding.Mesh = None,
+ dtype: jnp.dtype = jnp.float32,
+ weights_dtype: jnp.dtype = jnp.float32,
+ precision: jax.lax.Precision = None,
+ attention: str = "dot_product",
+ ):
+ self.kernel_size = _canonicalize_tuple(kernel_size, 3, 'kernel_size')
+ self.stride = _canonicalize_tuple(stride, 3, 'stride')
+ padding_tuple = _canonicalize_tuple(padding, 3, 'padding') # (D, H, W) padding amounts
+
+ self._causal_padding = (
+ (0, 0), # Batch dimension - no padding
+ (2 * padding_tuple[0], 0), # Depth dimension - causal padding (pad only before)
+ (padding_tuple[1], padding_tuple[1]), # Height dimension - symmetric padding
+ (padding_tuple[2], padding_tuple[2]), # Width dimension - symmetric padding
+ (0, 0) # Channel dimension - no padding
+ )
+
+ # Store the amount of padding needed *before* the depth dimension for caching logic
+ self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0]
+
+ self.conv = nnx.Conv(
+ in_features=in_channels,
+ out_features=out_channels,
+ kernel_size=self.kernel_size,
+ strides=self.stride,
+ use_bias=use_bias,
+ padding='VALID', # Handle padding manually
+ rngs=rngs
+ )
+
+ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Array:
+ current_padding = list(self._causal_padding) # Mutable copy
+ padding_needed = self._depth_padding_before
+
+ if cache_x is not None and padding_needed > 0:
+ # Ensure cache has same spatial/channel dims, potentially different depth
+ assert cache_x.shape[0] == x.shape[0] and \
+ cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch"
+
+ cache_len = cache_x.shape[1]
+ x = jnp.concatenate([cache_x, x], axis=1) # Concat along depth (D)
+
+ padding_needed -= cache_len
+ if padding_needed < 0:
+ # Cache longer than needed padding, trim from start
+ x = x[:, -padding_needed:, ...]
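+ # (padding_needed is negative at this point, so the slice drops the surplus
+ # cached frames from the front: e.g. depth padding 2 with cache_len 4 keeps x[:, 2:, ...].)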
+ current_padding[1] = (0, 0) # No explicit padding needed now + else: + # Update depth padding needed + current_padding[1] = (padding_needed, 0) + + # Apply padding if any dimension requires it + padding_to_apply = tuple(current_padding) + if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads): + x_padded = jnp.pad(x, padding_to_apply, mode='constant', constant_values=0.0) + else: + x_padded = x + + out = self.conv(x_padded) + return out + +class WanRMS_norm(nnx.Module): + def __init__( + self, + dim: int, + *, + eps: float = 1e-6, + use_bias: bool = False, + rngs: nnx.Rngs + ): + self.eps = eps + shape = (dim,) + self.scale = dim ** 0.5 + # Initialize gamma as parameter + self.gamma = nnx.Param(jax.random.ones(rngs.params(), shape)) + if use_bias: + self.bias = nnx.Param(jax.random.zeros(rngs.params(), shape)) + else: + self.bias = None + + def __call__(self, x: jax.Array) -> jax.Array: + # Expects input channels in the last dimension + variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True) + inv_std = jax.lax.rsqrt(variance + self.eps) + normalized = x * inv_std * self.gamma.value * self.scale + if self.bias: + return normalized + self.bias.value + return normalized + + +class WanEncoder3d(nnx.Module): + pass + +class WanCausalConv3d(nnx.Module): + pass + +class WanDecoder3d(nnx.Module): + pass + +class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1,2,4,4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temporal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [ + -0.7571,-0.7089,-0.9113,0.1075,-0.1745,0.9653,-0.1517, 1.5508, + 0.4134,-0.0715,0.5517,-0.3632,-0.1922,-0.9497,0.2503,-0.2921, + ], + latents_std: List[float] = [ + 2.8184,1.4541,2.3275,2.6558,1.2196,1.7708,2.6052,2.0743, + 3.2687,2.1526,2.8652,1.5579,1.6382,1.1253,2.8251,1.9160, + ], + ): + self.z_dim = z_dim + self.temporal_downsample = temporal_downsample + self.temporal_upsample = temporal_downsample[::-1] + + self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) + + self.decoder = WanDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout + ) \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/wan/__init__.py b/src/maxdiffusion/pipelines/wan/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/pipelines/wan/pipeline_wan.py b/src/maxdiffusion/pipelines/wan/pipeline_wan.py new file mode 100644 index 000000000..6bd9fb09b --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/pipeline_wan.py @@ -0,0 +1,84 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" + +from typing import Union, List +from transformers import AutoTokenizer, UMT5EncoderModel +import torch +from ...models.wan.transformers.transformer_flux_wan_nnx import WanModel +from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from ...video_processor import VideoProcessor +#from ...schedulers import FlowMatchEulerDiscreteScheduler + +class WanPipeline(FlaxDiffusionPipeline): + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: WanModel, + vae: AutoencoderKLWan, + #scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + #scheduler=scheduler + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + + def _get_t5_prompt_embds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state + # prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds \ No newline at end of file diff --git a/src/maxdiffusion/schedulers/__init__.py b/src/maxdiffusion/schedulers/__init__.py index b630948e0..edd249de1 100644 --- a/src/maxdiffusion/schedulers/__init__.py +++ b/src/maxdiffusion/schedulers/__init__.py @@ -48,6 +48,7 @@ _import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"] _import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"] _import_structure["scheduling_sde_ve_flax"] = ["FlaxScoreSdeVeScheduler"] + _import_structure["scheduling_flow_match_euler_discrete_flax"] = ["FlowMatchEulerDiscreteScheduler"] _import_structure["scheduling_utils_flax"] = [ "FlaxKarrasDiffusionSchedulers", "FlaxSchedulerMixin", @@ -73,6 +74,7 @@ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler + from .scheduling_flow_match_euler_discrete_flax import FlowMatchEulerDiscreteScheduler from .scheduling_utils_flax import ( FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py 
new file mode 100644 index 000000000..db96f53fd --- /dev/null +++ b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py @@ -0,0 +1,71 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + # FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) + +@flax.struct.dataclass +class FlowMatchEulerDiscreteSchedulerState: + common: CommonSchedulerState + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(FlaxSchedulerOutput): + state: FlowMatchEulerDiscreteSchedulerState + +class FlowMatchEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): + # _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + time_shift_type: str = "exponential", + dtype: jnp.dtype = jnp.float32 + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> FlowMatchEulerDiscreteSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py new file mode 100644 index 000000000..d82a129e0 --- /dev/null +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -0,0 +1,30 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +import os +import unittest +import pytest +from absl.testing import absltest + +class WanVaeTest(unittest.TestCase): + def setUp(self): + WanVaeTest.dummy_data = {} + + # def test_3d_conv(self): + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/src/maxdiffusion/video_processor.py b/src/maxdiffusion/video_processor.py new file mode 100644 index 000000000..2da782b46 --- /dev/null +++ b/src/maxdiffusion/video_processor.py @@ -0,0 +1,113 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch + +from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist + + +class VideoProcessor(VaeImageProcessor): + r"""Simple video processor.""" + + def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor: + r""" + Preprocesses input video(s). + + Args: + video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`): + The input video. It can be one of the following: + * List of the PIL images. + * List of list of PIL images. + * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`). + * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, + width)`). + * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width, + num_channels)`. + * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height, + width)`. + height (`int`, *optional*, defaults to `None`): + The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to + get default height. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get + the default width. + """ + if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", + FutureWarning, + ) + video = np.concatenate(video, axis=0) + if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d torch.Tensor is deprecated." 
+ "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor", + FutureWarning, + ) + video = torch.cat(video, axis=0) + + # ensure the input is a list of videos: + # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) + # - if it is a single video, it is convereted to a list of one video. + if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: + video = list(video) + elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): + video = [video] + elif isinstance(video, list) and is_valid_image_imagelist(video[0]): + video = video + else: + raise ValueError( + "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" + ) + + video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + + # move the number of channels before the number of frames. + video = video.permute(0, 2, 1, 3, 4) + + return video + + def postprocess_video( + self, video: torch.Tensor, output_type: str = "np" + ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: + r""" + Converts a video tensor to a list of frames for export. + + Args: + video (`torch.Tensor`): The video as a tensor. + output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. + """ + batch_size = video.shape[0] + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = self.postprocess(batch_vid, output_type) + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + elif output_type == "pt": + outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") + + return outputs From cc11bb155f92e54cbb52bb7b39814c2fd7614953 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 23 Apr 2025 20:11:40 +0000 Subject: [PATCH 05/25] added tests for a couple of wan vae layers. --- .../models/wan/autoencoder_kl_wan.py | 369 +++++++++++++++++- src/maxdiffusion/tests/wan_vae_test.py | 134 ++++++- 2 files changed, 485 insertions(+), 18 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 65da97bc9..9d0f89621 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -22,9 +22,12 @@ from ...configuration_utils import ConfigMixin, flax_register_to_config from ..modeling_flax_utils import FlaxModelMixin from ... 
import common_types +from ..vae_flax import FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution BlockSizes = common_types.BlockSizes +CACHE_T = 2 + _ACTIVATIONS = { "swish": jax.nn.silu, "silu": jax.nn.silu, @@ -128,43 +131,350 @@ class WanRMS_norm(nnx.Module): def __init__( self, dim: int, - *, + rngs: nnx.Rngs, + channel_first: bool = True, + images: bool = True, eps: float = 1e-6, use_bias: bool = False, - rngs: nnx.Rngs ): + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim, ) self.eps = eps - shape = (dim,) + self.channel_first = channel_first self.scale = dim ** 0.5 # Initialize gamma as parameter - self.gamma = nnx.Param(jax.random.ones(rngs.params(), shape)) + self.gamma = nnx.Param(jnp.ones(shape)) if use_bias: - self.bias = nnx.Param(jax.random.zeros(rngs.params(), shape)) + self.bias = nnx.Param(jnp.zeros(shape)) else: - self.bias = None + self.bias = 0 def __call__(self, x: jax.Array) -> jax.Array: - # Expects input channels in the last dimension - variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True) - inv_std = jax.lax.rsqrt(variance + self.eps) - normalized = x * inv_std * self.gamma.value * self.scale + normalized = jnp.linalg.norm(x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True) + normalized = x / jnp.maximum(normalized, self.eps) + normalized = normalized * self.scale * self.gamma if self.bias: return normalized + self.bias.value return normalized +class WanUpsample(nnx.Module): + def __init__(self, scale_factor: Tuple[float, float], method: str = 'nearest'): + # scale_factor for (H, W) + # JAX resize works on spatial dims, H, W assumming (N, D, H, W, C) or (N, H, W, C) + self.scale_factor = scale_factor + self.method = method + + def __call__(self, x: jax.Array) -> jax.Array: + input_dtype = x.dtype + in_shape = x.shape + is_3d = len(in_shape) == 5 + n, d, h, w, c = in_shape if is_3d else(in_shape[0], 1, in_shape[1], in_shape[2], in_shape[3]) + + target_h = int(h * self.scale_factor[0]) + target_w = int(w * self.scale_factor[1]) + + # jax.image.resize expects (..., H, W, C) + if is_3d: + x_reshaped = x.reshape(n * d, h, w, c) + out_reshaped = jax.image.resize(x_reshaped.astype(jnp.float32), + (n * d, target_h, target_w, c), + method=self.method) + out = out_reshaped.reshape(n, d, target_h, target_w, c) + else: # Asumming (N, H, W, C) + out = jax.image.resize(x.astype(jnp.float32), + (n, target_h, target_w, c), + method=self.method) + + return out.astype(input_dtype) + +class Identity(nnx.Module): + def __call__(self, x): + return x + +class ZeroPaddedConv2D(nnx.Module): + def __init__( + self, + dim: int, + rngs: nnx.Rngs, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", + ): + self.conv = nnx.Conv( + dim, + dim, + kernel_size=(1, 3, 3), + padding='SAME', + use_bias=True, + rngs=rngs + ) + + def __call__(self, x): + # This pad assumes (B, C, H, W) + x = jax.lax.pad(x, 0.0, [(0, 0, 0), (0, 0, 0), (0, 1, 0), (0, 1, 0)]) + return self.conv(x) + + +class WanResample(nnx.Module): + def __init__( + self, + dim: int, + mode: str, + rngs: nnx.Rngs, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + 
attention: str = "dot_product", + ): + self.dim = dim + self.mode = mode + self.time_conv = None + + if mode == "upsample2d": + self.resample = nnx.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), + nnx.Conv( + dim, + dim // 2, + kernel_size=(1, 3, 3), + padding='SAME', + use_bias=True, + rngs=rngs, + ) + ) + elif mode == "upsample3d": + self.resample = nnx.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), + nnx.Conv( + dim, + dim // 2, + kernel_size=(1, 3, 3), + padding='SAME', + use_bias=True, + rngs=rngs, + ) + ) + self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0), rngs=rngs) + elif mode == "downsample2d": + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs) + elif mode == "downsample3d": + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs) + else: + self.resample = Identity() + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: + return x + +class WanResidualBlock(nnx.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + rngs: nnx.Rngs, + dropout: float = 0.0, + non_linearity: str = "silu", + ): + pass + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + return x + +class WanAttentionBlock(nnx.Module): + def __init__( + self, + dim: int, + rngs: nnx.Rngs + ): + self.dim = dim + + def __call__(self, x: jax.Array): + return x + +class WanMidBlock(nnx.Module): + def __init__( + self, + dim: int, + rngs: nnx.Rngs, + dropout: float = 0.0, + non_linearity: str = "silu", + num_layers: int = 1 + ): + self.dim = dim + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + return x + +class WanUpBlock(nnx.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + rngs: nnx.Rngs, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu" + ): + # Create layers list + self.resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + self.resnets.append(WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs)) + current_dim = out_dim + + # Add upsampling layer if needed. 
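+ # (illustrative note: "upsample2d"/"upsample3d" halve the channel count in
+ # WanResample; the 3d mode is also meant to double the temporal length via
+ # its time conv, per the PyTorch reference implementation.)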
+ self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs) + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + return x class WanEncoder3d(nnx.Module): - pass + def __init__( + self, + rngs: nnx.Rngs, + dim: int =128, + z_dim: int = 4, + dim_mult = [1, 2, 4, 4], + num_res_blocks = 2, + attn_scales = [], + temperal_downsample = [True, True, False], + dropout = 0.0, + non_linearity: str = 'silu', + ): + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) -class WanCausalConv3d(nnx.Module): - pass + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1, rngs=rngs) + + # downsample blocks + self.down_blocks = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(WanResidualBlock(in_dim=in_dim, out_dim=out_dim, dropout=dropout, rngs=rngs)) + if scale in attn_scales: + self.down_blocks.append(WanAttentionBlock(dim=out_dim, rngs=rngs)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs)) + + # middle_blocks + self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1, rngs=rngs) + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False, rngs=rngs) + self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=z_dim, kernel_size=3, padding=1) + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + return x class WanDecoder3d(nnx.Module): - pass + r""" + A 3D decoder module. + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
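+
+ Note (illustrative): with dim=128 and dim_mult=[1, 2, 4, 4], the channel
+ schedule is dims=[512, 512, 512, 256, 128]; because each upsample halves
+ the channel count, every block after the first sees in_dim // 2.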
+ """
+ def __init__(
+ self,
+ rngs: nnx.Rngs,
+ dim: int = 128,
+ z_dim: int = 128,
+ dim_mult: List[int] = [1, 2, 4, 4],
+ num_res_blocks: int = 2,
+ attn_scales: List[float] = [],
+ temperal_upsample=[False, True, True],
+ dropout=0.0,
+ non_linearity: str = "silu",
+ ):
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_upsample = temperal_upsample
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ # dimensions
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
+
+ # init block
+ self.conv_in = WanCausalConv3d(
+ rngs=rngs,
+ in_channels=z_dim,
+ out_channels=dims[0],
+ kernel_size=3,
+ padding=1
+ )
+
+ # middle_blocks
+ self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1)
+
+ # upsample blocks
+ self.up_blocks = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ if i > 0:
+ in_dim = in_dim // 2
+
+ # Determine if we need upsampling
+ upsample_mode = None
+ if i != len(dim_mult) - 1:
+ upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
+
+ # Create and add the upsampling block
+ up_block = WanUpBlock(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ num_res_blocks=num_res_blocks,
+ dropout=dropout,
+ upsample_mode=upsample_mode,
+ non_linearity=non_linearity,
+ rngs=rngs
+ )
+ self.up_blocks.append(up_block)
+
+ # Update scale for next iteration
+ if upsample_mode is not None:
+ scale *= 2.0
+
+ # output blocks
+ self.norm_out = nnx.RMSNorm(num_features=out_dim, )
+ self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs)
+ self.conv_out = WanCausalConv3d(
+ rngs=rngs,
+ in_channels=out_dim,
+ out_channels=3,
+ kernel_size=3,
+ padding=1
+ )
+
+ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
+ return x

 class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin):
 def __init__(
 self,
+ rngs: nnx.Rngs,
 base_dim: int = 96,
 z_dim: int = 16,
 dim_mult: Tuple[int] = [1,2,4,4],
@@ -186,8 +496,35 @@ def __init__(
 self.temporal_upsample = temporal_downsample[::-1]
 
 self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1)
- self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
+ self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1, rngs=rngs)
 
 self.decoder = WanDecoder3d(
 base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout
- )
\ No newline at end of file
+ )
+ self.clear_cache()
+
+ def clear_cache(self):
+ """ Resets cache dictionaries and indices"""
+ def _count_conv3d(module):
+ count = 0
+ node_types = nnx.graph.iter_graph(module, nnx.Module)
+ for node in node_types:
+ if isinstance(node.value, WanCausalConv3d):
+ count +=1
+ return count
+
+ self._conv_num = _count_conv3d(self.decoder)
+ self._conv_idx = [0]
+ self._feat_map = [None] * self._conv_num
+ # cache encode
+ self._enc_conv_num = _count_conv3d(self.encoder)
+ self._enc_conv_idx = [0]
+ self._enc_feat_map = [None] * self._enc_conv_num
+
+
+ def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
+ """ Encode video into latent distribution."""
+ if x.shape[-1] != 3:
+ raise ValueError(f"Expected input shape (N, D, H, W, 3), got {x.shape}")
+
+ self.clear_cache()
\ No newline at end of file
diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py
index d82a129e0..a65f0c425 100644
--- 
a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -15,16 +15,146 @@ """ import os +import jax +import jax.numpy as jnp +from flax import nnx +import numpy as np import unittest import pytest from absl.testing import absltest +from ..models.wan.autoencoder_kl_wan import WanCausalConv3d, WanUpsample, AutoencoderKLWan, WanRMS_norm class WanVaeTest(unittest.TestCase): def setUp(self): WanVaeTest.dummy_data = {} - # def test_3d_conv(self): + # def test_clear_cache(self): + # key = jax.random.key(0) + # rngs = nnx.Rngs(key) + # wan_vae = AutoencoderKLWan(rngs=rngs) + # wan_vae.clear_cache() + + def test_wanrms_norm(self): + """Test against the Pytorch implementation""" + import torch + import torch.nn as nn + import torch.nn.functional as F + + class TorchWanRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + # --- Test Case 1: images == True --- + model = TorchWanRMS_norm(2) + input_shape = (1, 2, 2, 2, 3) + input = torch.ones(input_shape) + torch_output = model(input) + torch_output_np = torch_output.detach().numpy() + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + wanrms_norm = WanRMS_norm(dim=2, rngs=rngs) + input_shape = (1, 2, 2, 2, 3) + dummy_input = jnp.ones(input_shape) + output = wanrms_norm(dummy_input) + output_np = np.array(output) + assert np.allclose(output_np, torch_output_np) == True + + # --- Test Case 2: images == False --- + model = TorchWanRMS_norm(2, images=False) + input_shape = (1, 2, 2, 2, 3) + input = torch.ones(input_shape) + torch_output = model(input) + torch_output_np = torch_output.detach().numpy() + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + wanrms_norm = WanRMS_norm(dim=2, rngs=rngs, images=False) + input_shape = (1, 2, 2, 2, 3) + dummy_input = jnp.ones(input_shape) + output = wanrms_norm(dummy_input) + output_np = np.array(output) + assert np.allclose(output_np, torch_output_np) == True + + def test_wan_upsample(self): + batch_size=1 + in_depth, in_height, in_width = 10, 32, 32 + in_channels = 3 + + dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels)) + + upsample = WanUpsample(scale_factor=(2.0, 2.0)) + + # --- Test Case 1: depth > 1 --- + output = upsample(dummy_input) + assert output.shape == (1, 10, 64, 64, 3) + + in_depth = 1 + dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels)) + # --- Test Case 1: depth == 1 --- + output = upsample(dummy_input) + assert output.shape == (1, 1, 64, 64, 3) + + def test_3d_conv(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + batch_size=1 + in_depth, in_height, in_width = 10, 32, 32 + in_channels 
= 3 + out_channels = 16 + kernel_d, kernel_h, kernel_w = 3, 3, 3 # Kernel size (Depth, Height, Width) + padding_d, padding_h, padding_w = 1, 1, 1 # Base padding (Depth, Height, Width) + + # Create dummy input data + dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels)) + + # Create dummy cache data (from a previous step) + cache_depth = 2 * padding_d + dummy_cache = jnp.zeros((batch_size, cache_depth, in_height, in_width, in_channels)) + + # Instantiate the module + causal_conv_layer = WanCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(kernel_d, kernel_h, kernel_w), + padding=(padding_d, padding_h, padding_w), + rngs=rngs # Pass rngs for initialization + ) + + # --- Test Case 1: No Cache --- + output_no_cache = causal_conv_layer(dummy_input) + assert output_no_cache.shape == (1, 10, 32, 32, 16) + + # --- Test Case 2: With Cache --- + output_with_cache = causal_conv_layer(dummy_input, cache_x=dummy_cache) + assert output_with_cache.shape == (1, 10, 32, 32, 16) + + # --- Test Case 3: With Cache larger than padding --- + larger_cache_depth = 4 # Larger than needed padding (2*padding_d = 2) + dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels)) + output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache) + assert output_with_larger_cache.shape == (1, 10, 32, 32, 16) - if __name__ == "__main__": absltest.main() \ No newline at end of file From 4e443b82e58a0a4adca9678209ac91927b6bcbb1 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 23 Apr 2025 22:18:31 +0000 Subject: [PATCH 06/25] add unit tests to wan vae padded conv. --- src/maxdiffusion/generate_wan.py | 1 - .../models/wan/autoencoder_kl_wan.py | 30 ++++++++--- src/maxdiffusion/tests/wan_vae_test.py | 51 +++++++++++++++---- 3 files changed, 66 insertions(+), 16 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index c35aafecc..e19d783de 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -214,7 +214,6 @@ def run(config): vae=None, transformer=None ) - breakpoint() #wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 9d0f89621..61473ca5b 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -192,10 +192,16 @@ def __call__(self, x): return x class ZeroPaddedConv2D(nnx.Module): + """ + Module for adding padding before conv. + Currently it does not add any padding. 
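+ With nnx.Conv's default 'SAME' padding and a (1, 2, 2) stride, a
+ (480, 720) frame maps to (240, 360), matching the output shape of the
+ PyTorch nn.ZeroPad2d((0, 1, 0, 1)) + stride-2 conv reference in the tests.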
+ """ def __init__( self, dim: int, rngs: nnx.Rngs, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, flash_min_seq_length: int = 4096, flash_block_sizes: BlockSizes = None, mesh: jax.sharding.Mesh = None, @@ -204,18 +210,18 @@ def __init__( precision: jax.lax.Precision = None, attention: str = "dot_product", ): + kernel_size = _canonicalize_tuple(kernel_size, 3, 'kernel_size') + stride = _canonicalize_tuple(stride, 3, 'stride') self.conv = nnx.Conv( dim, dim, - kernel_size=(1, 3, 3), - padding='SAME', + kernel_size=kernel_size, + strides=stride, use_bias=True, rngs=rngs ) def __call__(self, x): - # This pad assumes (B, C, H, W) - x = jax.lax.pad(x, 0.0, [(0, 0, 0), (0, 0, 0), (0, 1, 0), (0, 1, 0)]) return self.conv(x) @@ -263,9 +269,21 @@ def __init__( ) self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0), rngs=rngs) elif mode == "downsample2d": - self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs) + # TODO - do I need to transpose? + self.resample = ZeroPaddedConv2D( + dim=dim, + rngs=rngs, + kernel_size=(1, 3, 3), + stride=(1, 2, 2) + ) elif mode == "downsample3d": - self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs) + # TODO - do I need to transpose? + self.resample = ZeroPaddedConv2D( + dim=dim, + rngs=rngs, + kernel_size=(1, 3, 3), + stride=(1, 2, 2) + ) else: self.resample = Identity() diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index a65f0c425..3aa64a133 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -22,7 +22,13 @@ import unittest import pytest from absl.testing import absltest -from ..models.wan.autoencoder_kl_wan import WanCausalConv3d, WanUpsample, AutoencoderKLWan, WanRMS_norm +from ..models.wan.autoencoder_kl_wan import ( + WanCausalConv3d, + WanUpsample, + AutoencoderKLWan, + WanRMS_norm, + ZeroPaddedConv2D +) class WanVaeTest(unittest.TestCase): def setUp(self): @@ -66,36 +72,63 @@ def forward(self, x): return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias # --- Test Case 1: images == True --- - model = TorchWanRMS_norm(2) - input_shape = (1, 2, 2, 2, 3) + dim = 96 + input_shape = (1, 96, 1, 480, 720) + + model = TorchWanRMS_norm(dim) input = torch.ones(input_shape) torch_output = model(input) torch_output_np = torch_output.detach().numpy() key = jax.random.key(0) rngs = nnx.Rngs(key) - wanrms_norm = WanRMS_norm(dim=2, rngs=rngs) - input_shape = (1, 2, 2, 2, 3) + wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs) dummy_input = jnp.ones(input_shape) output = wanrms_norm(dummy_input) output_np = np.array(output) assert np.allclose(output_np, torch_output_np) == True # --- Test Case 2: images == False --- - model = TorchWanRMS_norm(2, images=False) - input_shape = (1, 2, 2, 2, 3) + model = TorchWanRMS_norm(dim, images=False) input = torch.ones(input_shape) torch_output = model(input) torch_output_np = torch_output.detach().numpy() key = jax.random.key(0) rngs = nnx.Rngs(key) - wanrms_norm = WanRMS_norm(dim=2, rngs=rngs, images=False) - input_shape = (1, 2, 2, 2, 3) + wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs, images=False) dummy_input = jnp.ones(input_shape) output = wanrms_norm(dummy_input) output_np = np.array(output) assert np.allclose(output_np, torch_output_np) == True + + def test_zero_padded_conv(self): + import torch + import torch.nn as nn + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + + dim = 96 + kernel_size = 3 + stride= (2, 2) + resample = 
nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, kernel_size, stride=stride)) + input_shape = (1, 96, 480, 720) + input = torch.ones(input_shape) + output_torch = resample(input) + assert output_torch.shape == (1, 96, 240, 360) + + model = ZeroPaddedConv2D( + dim=dim, + rngs=rngs, + kernel_size=(1, 3, 3), + stride=(1, 2, 2) + ) + dummy_input = jnp.ones(input_shape) + dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1)) + output = model(dummy_input) + output = jnp.transpose(output, (0, 3, 1, 2)) + assert output.shape == (1, 96, 240, 360) def test_wan_upsample(self): batch_size=1 From aeabe27e586cc1288e92229c5cdb34635301ff1e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 24 Apr 2025 17:55:19 +0000 Subject: [PATCH 07/25] wip - test for vae encoder. --- .../models/wan/autoencoder_kl_wan.py | 88 +++++++- src/maxdiffusion/tests/wan_vae_test.py | 196 ++++++++++++++++-- 2 files changed, 256 insertions(+), 28 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 61473ca5b..faac464cc 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -55,14 +55,13 @@ def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> T class WanCausalConv3d(nnx.Module): def __init__( self, + rngs: nnx.Rngs, # rngs are required for initializing parameters, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int, int]], - *, # Mark subsequent arguments as keyword-only stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, use_bias: bool = True, - rngs: nnx.Rngs, # rngs are required for initializing parameters, flash_min_seq_length: int = 4096, flash_block_sizes: BlockSizes = None, mesh: jax.sharding.Mesh = None, @@ -267,7 +266,13 @@ def __init__( rngs=rngs, ) ) - self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0), rngs=rngs) + self.time_conv = WanCausalConv3d( + rngs=rngs, + in_channels=dim, + out_channels=dim * 2, + kernel_size=(3, 1, 1), + padding=(1, 0, 0), + ) elif mode == "downsample2d": # TODO - do I need to transpose? 
 self.resample = ZeroPaddedConv2D(
@@ -288,6 +293,15 @@ def __init__(
 self.resample = Identity()
 
 def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
+ # Input x: (N, D, H, W, C), assume C = self.dim
+ n, d, h, w, c = x.shape
+ assert c == self.dim
+
+ x = x.reshape(n * d, h, w, c)
+ x = self.resample(x)
+ h_new, w_new, c_new = x.shape[1:]
+ x = x.reshape(n, d, h_new, w_new, c_new)
+
 return x
 
 class WanResidualBlock(nnx.Module):
@@ -382,7 +396,13 @@ def __init__(
 scale = 1.0
 
 # init block
- self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1, rngs=rngs)
+ self.conv_in = WanCausalConv3d(
+ in_channels=3,
+ out_channels=dims[0],
+ kernel_size=3,
+ padding=1,
+ rngs=rngs
+ )
 
 # downsample blocks
 self.down_blocks = []
@@ -400,11 +420,23 @@ def __init__(
 self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs))
 
 # middle_blocks
- self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1, rngs=rngs)
+ self.mid_block = WanMidBlock(
+ dim=out_dim,
+ rngs=rngs,
+ dropout=dropout,
+ non_linearity=non_linearity,
+ num_layers=1,
+ )
 
 # output blocks
 self.norm_out = WanRMS_norm(out_dim, images=False, rngs=rngs)
- self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=z_dim, kernel_size=3, padding=1)
+ self.conv_out = WanCausalConv3d(
+ rngs=rngs,
+ in_channels=out_dim,
+ out_channels=z_dim,
+ kernel_size=3,
+ padding=1
+ )
 
 def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
@@ -487,6 +519,9 @@ def __init__(
 self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=3, kernel_size=3, padding=1, rngs=rngs)
 
 def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
+ breakpoint()
+ x = self.conv_in(x)
+ breakpoint()
 return x
 
 class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin):
@@ -514,6 +549,7 @@ def __init__(
 self.temporal_upsample = temporal_downsample[::-1]
 
 self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1)
+ self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1, rngs=rngs)
 self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1, rngs=rngs)
 
 self.decoder = WanDecoder3d(
@@ -539,10 +575,42 @@ def _count_conv3d(module):
 self._enc_conv_idx = [0]
 self._enc_feat_map = [None] * self._enc_conv_num
 
+ def _encode(self, x: jax.Array):
+ if x.shape[-1] != 3:
+ # reshape to channel-last for JAX
+ x = jnp.transpose(x, (0, 2, 3, 4, 1))
+ assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
+
+ self.clear_cache()
+
+ t = x.shape[1]
+ iter_ = 1 + (t - 1) // 4
+ for i in range(iter_):
+ if i == 0:
+ out = self.encoder(
+ x[:, :1, :, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx
+ )
+ else:
+ out_ = self.encoder(
+ x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx
+ )
+ out = jnp.concatenate([out, out_], axis=1)
+
+ enc = self.quant_conv(out)
+ mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
+ enc = jnp.concatenate([mu, logvar], axis=-1)
+ self.clear_cache()
+ return enc

 def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
 """ Encode video into latent distribution."""
- if x.shape[-1] != 3:
- raise ValueError(f"Expected input shape (N, D, H, W, 3), got {x.shape}")
-
- self.clear_cache()
\ No newline at end of file
+ h = self._encode(x)
+ posterior = FlaxDiagonalGaussianDistribution(h)
+ if not return_dict:
+ return (posterior, )
+ return FlaxAutoencoderKLOutput(latent_dist=posterior)
+ 
\ No newline at end of file
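The `_encode` loop above follows the Wan VAE's causal temporal chunking: the first frame is encoded on its own, then the remaining frames are consumed four at a time, with the per-convolution caches carrying state across chunks. A minimal sketch of the resulting frame partition, assuming only the `iter_ = 1 + (t - 1) // 4` arithmetic shown above (the helper name is hypothetical, not part of the patch):

```python
# Illustrative only: reproduce the frame partition used by _encode above.
def encode_chunk_bounds(t: int) -> list:
    bounds = [(0, 1)]  # the first frame is encoded alone
    for i in range(1, 1 + (t - 1) // 4):
        bounds.append((1 + 4 * (i - 1), 1 + 4 * i))  # then four frames per step
    return bounds

# A 49-frame clip yields 13 encoder calls:
# [(0, 1), (1, 5), (5, 9), ..., (45, 49)]
print(encode_chunk_bounds(49))
```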
diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 3aa64a133..0e5df142f 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -15,6 +15,9 @@ """ import os +import torch +import torch.nn as nn +import torch.nn.functional as F import jax import jax.numpy as jnp from flax import nnx @@ -26,27 +29,15 @@ WanCausalConv3d, WanUpsample, AutoencoderKLWan, + WanEncoder3d, WanRMS_norm, + WanResample, ZeroPaddedConv2D ) -class WanVaeTest(unittest.TestCase): - def setUp(self): - WanVaeTest.dummy_data = {} - - # def test_clear_cache(self): - # key = jax.random.key(0) - # rngs = nnx.Rngs(key) - # wan_vae = AutoencoderKLWan(rngs=rngs) - # wan_vae.clear_cache() +CACHE_T = 2 - def test_wanrms_norm(self): - """Test against the Pytorch implementation""" - import torch - import torch.nn as nn - import torch.nn.functional as F - - class TorchWanRMS_norm(nn.Module): +class TorchWanRMS_norm(nn.Module): r""" A custom RMS normalization layer. @@ -70,6 +61,103 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi def forward(self, x): return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + +class TorchWanResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. 
+ """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + +class WanVaeTest(unittest.TestCase): + def setUp(self): + WanVaeTest.dummy_data = {} + + # def test_clear_cache(self): + # key = jax.random.key(0) + # rngs = nnx.Rngs(key) + # wan_vae = AutoencoderKLWan(rngs=rngs) + # wan_vae.clear_cache() + + def test_wanrms_norm(self): + """Test against the Pytorch implementation""" # --- Test Case 1: images == True --- dim = 96 @@ -103,8 +191,6 @@ def forward(self, x): assert np.allclose(output_np, torch_output_np) == True def test_zero_padded_conv(self): - import torch - import torch.nn as nn key = jax.random.key(0) rngs = nnx.Rngs(key) @@ -148,6 +234,49 @@ def test_wan_upsample(self): # --- Test Case 1: depth == 1 --- output = upsample(dummy_input) assert output.shape == (1, 1, 64, 64, 3) + + def test_wan_resample(self): + # TODO - needs to test all modes - upsample2d, upsample3d, downsample2d, downsample3d and identity + key = jax.random.key(0) + rngs = nnx.Rngs(key) + + # --- Test Case 1: downsample2d --- + batch = 1 + dim = 96 + t = 1 + h = 480 + w = 720 + mode="downsample2d" + input_shape = (batch, dim, t, h, w) + expected_output_shape = (1, dim, 1, 
240, 360)
+    # output dim should be (1, 96, 1, 240, 360)
+    dummy_input = torch.ones(input_shape)
+    torch_wan_resample = TorchWanResample(
+        dim=dim,
+        mode=mode
+    )
+    torch_output = torch_wan_resample(dummy_input)
+    assert torch_output.shape == (batch, dim, t, h // 2, w // 2)
+
+    wan_resample = WanResample(
+        dim,
+        mode=mode,
+        rngs=rngs
+    )
+    # channels is always last here
+    input_shape = (batch, t, h, w, dim)
+    dummy_input = jnp.ones(input_shape)
+    output = wan_resample(dummy_input)
+    assert output.shape == (batch, t, h // 2, w // 2, dim)
+
+
+    # --- Test Case 2: downsample3d ---
+    dim = 192
+    input_shape = (1, dim, 1, 240, 360)
+    torch_wan_resample = TorchWanResample(
+        dim=dim,
+        mode="downsample3d"
+    )
 
   def test_3d_conv(self):
     key = jax.random.key(0)
@@ -189,5 +318,36 @@ def test_3d_conv(self):
     output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache)
     assert output_with_larger_cache.shape == (1, 10, 32, 32, 16)
 
+  def test_wan_encode(self):
+    key = jax.random.key(0)
+    rngs = nnx.Rngs(key)
+    dim = 96
+    z_dim = 32
+    dim_mult = [1, 2, 4, 4]
+    num_res_blocks = 2
+    attn_scales = []
+    temperal_downsample = [False, True, True]
+    nonlinearity = "silu"
+    wan_encoder = WanEncoder3d(
+        rngs=rngs,
+        dim=dim,
+        z_dim=z_dim,
+        dim_mult=dim_mult,
+        num_res_blocks=num_res_blocks,
+        attn_scales=attn_scales,
+        temperal_downsample=temperal_downsample,
+        non_linearity=nonlinearity
+    )
+    batch = 1
+    channels = 3
+    t = 49
+    height = 480
+    width = 720
+    input_shape = (batch, channels, t, height, width)
+    input = jnp.ones(input_shape)
+    output = wan_encoder(input)
+
+
 
 if __name__ == "__main__":
   absltest.main()
\ No newline at end of file

From 0ec4b02d00ec589db92e0f601bf955fad7c178bf Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Thu, 24 Apr 2025 21:34:40 +0000
Subject: [PATCH 08/25] Residual block test

---
 .../models/wan/autoencoder_kl_wan.py          | 140 ++++++++++++++----
 src/maxdiffusion/tests/wan_vae_test.py        |  70 +++++++--
 2 files changed, 173 insertions(+), 37 deletions(-)

diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
index faac464cc..8f6e2789a 100644
--- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
+++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -313,10 +313,49 @@ def __init__(
       dropout: float = 0.0,
       non_linearity: str = "silu",
   ):
-    pass
+    self.nonlinearity = get_activation(non_linearity)
+
+    # layers
+    self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False)
+    self.conv1 = WanCausalConv3d(
+        rngs=rngs,
+        in_channels=in_dim,
+        out_channels=out_dim,
+        kernel_size=3,
+        padding=1
+    )
+    self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False)
+    self.dropout = nnx.Dropout(dropout, rngs=rngs)
+    self.conv2 = WanCausalConv3d(
+        rngs=rngs,
+        in_channels=out_dim,
+        out_channels=out_dim,
+        kernel_size=3,
+        padding=1
+    )
+    self.conv_shortcut = WanCausalConv3d(
+        rngs=rngs,
+        in_channels=in_dim,
+        out_channels=out_dim,
+        kernel_size=1
+    ) if in_dim != out_dim else Identity()
+

   def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
-    return x
+    # Apply shortcut connection
+    #breakpoint()
+    h = self.conv_shortcut(x)
+
+    x = self.norm1(x)
+    x = self.nonlinearity(x)
+    x = self.conv1(x)
+
+    x = self.norm2(x)
+    x = self.nonlinearity(x)
+    x = self.dropout(x)
+    x = self.conv2(x)
+
+    return x + h

 class WanAttentionBlock(nnx.Module):
   def __init__(
@@ -397,11 +436,11 @@ def __init__(

     # init block
     self.conv_in = WanCausalConv3d(
+        rngs=rngs,
in_channels=3, out_channels=dims[0], kernel_size=3, padding=1, - rngs=rngs ) # downsample blocks @@ -439,6 +478,12 @@ def __init__( ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + # (1, 1, 480, 720, 3) + x = self.conv_in(x) + # (1, 1, 480, 720, 96) + for layer in self.down_blocks: + x = layer(x) + breakpoint() return x class WanDecoder3d(nnx.Module): @@ -480,7 +525,13 @@ def __init__( scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block - self.conv_in = WanCausalConv3d(in_channels=z_dim, out_channels=dims[0], kernel_size=3, padding=1, rngs=rngs) + self.conv_in = WanCausalConv3d( + rngs=rngs, + in_channels=z_dim, + out_channels=dims[0], + kernel_size=3, + padding=1 + ) # middle_blocks self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1) @@ -516,7 +567,13 @@ def __init__( # output blocks self.norm_out = nnx.RMSNorm(num_features=out_dim, ) self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs) - self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=3, kernel_size=3, padding=1, rngs=rngs) + self.conv_out = WanCausalConv3d( + rngs=rngs, + in_channels=out_dim, + out_channels=3, + kernel_size=3, + padding=1 + ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): breakpoint() @@ -533,7 +590,7 @@ def __init__( dim_mult: Tuple[int] = [1,2,4,4], num_res_blocks: int = 2, attn_scales: List[float] = [], - temporal_downsample: List[bool] = [False, True, True], + temperal_downsample: List[bool] = [False, True, True], dropout: float = 0.0, latents_mean: List[float] = [ -0.7571,-0.7089,-0.9113,0.1075,-0.1745,0.9653,-0.1517, 1.5508, @@ -545,31 +602,59 @@ def __init__( ], ): self.z_dim = z_dim - self.temporal_downsample = temporal_downsample - self.temporal_upsample = temporal_downsample[::-1] - - self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1) - self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1, rngs=rngs) - self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1, rngs=rngs) + self.temperal_downsample = temperal_downsample + self.temporal_upsample = temperal_downsample[::-1] - self.decoder = WanDecoder3d( - base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout + self.encoder = WanEncoder3d( + rngs=rngs, + dim=base_dim, + z_dim=z_dim * 2, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + dropout=dropout, + ) + self.quant_conv = WanCausalConv3d( + rngs=rngs, + in_channels=z_dim * 2, + out_channels=z_dim * 2, + kernel_size=1 + ) + self.post_quant_conv = WanCausalConv3d( + rngs=rngs, + in_channels=z_dim, + out_channels=z_dim, + kernel_size=1, ) + + # self.decoder = WanDecoder3d( + # rngs=rngs, + # dim=base_dim, + # z_dim=z_dim, + # dim_mult=dim_mult, + # num_res_blocks=num_res_blocks, + # attn_scales=attn_scales, + # temperal_upsample=self.temporal_upsample, + # dropout=dropout + # ) self.clear_cache() def clear_cache(self): """ Resets cache dictionaries and indices""" def _count_conv3d(module): count = 0 - node_types = nnx.graph.iter_graph(module, nnx.Module) - for node in node_types: - if isinstance(node.value, WanCausalConv3d): + node_types = nnx.graph.iter_graph([module]) + for path, value in node_types: + #breakpoint() + if isinstance(value, WanCausalConv3d): + print("value: ", value) count +=1 return count - self._conv_num = _count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num + # self._conv_num = _count_conv3d(self.decoder) + # 
self._conv_idx = [0] + # self._feat_map = [None] * self._conv_num # cache encode self._enc_conv_num = _count_conv3d(self.encoder) self._enc_conv_idx = [0] @@ -581,7 +666,7 @@ def _encode(self, x: jax.Array): x = jnp.transpose(x, (0, 2, 3, 4, 1)) assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}" - self.clear_cache() + #self.clear_cache() t = x.shape[1] iter_ = 1 + (t - 1) // 4 @@ -590,7 +675,7 @@ def _encode(self, x: jax.Array): out = self.encoder( x[:, :1, :, :, :], feat_cache=self._enc_feat_map, - feat_ids=self._enc_conv_idx + feat_idx=self._enc_conv_idx ) else: out_ = self.encoder( @@ -600,11 +685,12 @@ def _encode(self, x: jax.Array): ) out = jnp.concatenate([out, out_], axis=1) - enc = self.quant_conv(out) - mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] - enc = jnp.concatenate([mu, logvar], dim=1) - self.clear_cache() - return enc + # enc = self.quant_conv(out) + # mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] + # enc = jnp.concatenate([mu, logvar], dim=1) + # self.clear_cache() + # return enc + return x def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: """ Encode video into latent distribution.""" diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 0e5df142f..17dd7c45d 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -30,6 +30,7 @@ WanUpsample, AutoencoderKLWan, WanEncoder3d, + WanResidualBlock, WanRMS_norm, WanResample, ZeroPaddedConv2D @@ -318,6 +319,45 @@ def test_3d_conv(self): output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache) assert output_with_larger_cache.shape == (1, 10, 32, 32, 16) + def test_wan_residual(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + # one test + in_dim = out_dim = 96 + batch = 1 + t = 1 + height = 480 + width = 720 + dim = 96 + input_shape = (batch, t, height, width, dim) + expected_output_shape = (batch, t, height, width, dim) + + wan_residual_block = WanResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + rngs=rngs, + ) + dummy_input = jnp.ones(input_shape) + dummy_output = wan_residual_block(dummy_input) + assert dummy_output.shape == expected_output_shape + + # another test + in_dim = 96 + out_dim = 196 + expected_output_shape = (batch, t, height, width, out_dim) + + wan_residual_block = WanResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + rngs=rngs, + ) + dummy_input = jnp.ones(input_shape) + dummy_output = wan_residual_block(dummy_input) + assert dummy_output.shape == expected_output_shape + + + + def test_wan_encode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) @@ -328,16 +368,25 @@ def test_wan_encode(self): attn_scales = [] temperal_downsample = [False, True, True] nonlinearity = "silu" - wan_encoder = WanEncoder3d( - rngs=rngs, - dim=dim, - z_dim=z_dim, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_downsample=temperal_downsample, - non_linearity=nonlinearity + wan_vae = AutoencoderKLWan( + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, ) + # wan_encoder = WanEncoder3d( + # rngs=rngs, + # dim=dim, + # z_dim=z_dim, + # dim_mult=dim_mult, + # num_res_blocks=num_res_blocks, + # attn_scales=attn_scales, + # temperal_downsample=temperal_downsample, + # 
non_linearity=nonlinearity + # ) batch = 1 channels = 3 t = 49 @@ -345,7 +394,8 @@ def test_wan_encode(self): width = 720 input_shape = (batch, channels, t, height, width) input = jnp.ones(input_shape) - output = wan_encoder(input) + output = wan_vae.encode(input) + breakpoint() From 9b42117db99207e873b38ef2e65b9eca8ae68135 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 24 Apr 2025 23:05:30 +0000 Subject: [PATCH 09/25] add wan vae attention test --- .../models/wan/autoencoder_kl_wan.py | 38 ++++++++++++++++++- src/maxdiffusion/tests/wan_vae_test.py | 25 +++++++++--- 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 8f6e2789a..292264c2a 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -364,9 +364,45 @@ def __init__( rngs: nnx.Rngs ): self.dim = dim + self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False) + self.to_qkv = nnx.Conv( + in_features=dim, + out_features=dim * 3, + kernel_size=1, + rngs=rngs + ) + self.proj = nnx.Conv( + in_features=dim, + out_features=dim, + kernel_size=1, + rngs=rngs + ) def __call__(self, x: jax.Array): - return x + batch_size, time, height, width, channels = x.shape + identity = x + + x = x.reshape(batch_size * time, height, width, channels) + x = self.norm(x) + + qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3) + + qkv = qkv.reshape(batch_size*time, 1, channels * 3, -1) + qkv = jnp.transpose(qkv, (0, 1, 3, 2)) + q, k, v = jnp.split(qkv, 3, axis=-1) + + x = jax.nn.dot_product_attention(q, k, v) + x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels) + + #output projection + x = self.proj(x) + + # Reshape back + x = x.reshape(batch_size, time, height, width, channels) + + return x + identity + + class WanMidBlock(nnx.Module): def __init__( diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 17dd7c45d..f91f3932e 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -33,7 +33,8 @@ WanResidualBlock, WanRMS_norm, WanResample, - ZeroPaddedConv2D + ZeroPaddedConv2D, + WanAttentionBlock ) CACHE_T = 2 @@ -322,7 +323,7 @@ def test_3d_conv(self): def test_wan_residual(self): key = jax.random.key(0) rngs = nnx.Rngs(key) - # one test + # --- Test Case 1: same in/out dim --- in_dim = out_dim = 96 batch = 1 t = 1 @@ -341,7 +342,7 @@ def test_wan_residual(self): dummy_output = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape - # another test + # --- Test Case 1: different in/out dim --- in_dim = 96 out_dim = 196 expected_output_shape = (batch, t, height, width, out_dim) @@ -355,8 +356,22 @@ def test_wan_residual(self): dummy_output = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape - - + def test_wan_attention(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + dim = 384 + batch = 1 + t = 1 + height = 60 + width = 90 + input_shape=(batch, t, height, width, dim) + wan_attention = WanAttentionBlock( + dim=dim, + rngs=rngs + ) + dummy_input = jnp.ones(input_shape) + output = wan_attention(dummy_input) + assert output.shape == input_shape def test_wan_encode(self): key = jax.random.key(0) From 43253250cf6baa9a95be980d55091c67ce16b153 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 24 Apr 2025 23:31:49 +0000 Subject: [PATCH 10/25] add wan mid block vae test --- 
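Note: the 60x90 spatial grid and 13 latent frames used by the VAE tests from here
on follow from the encoder's downsampling schedule. A quick sanity check of that
arithmetic, as a sketch (numbers taken from the tests in this series):

    # dim_mult = [1, 2, 4, 4] implies three spatial 2x downsamples; _encode feeds
    # the first frame alone and then chunks of 4 frames.
    t, h, w = 49, 480, 720                       # pixel-space video from the tests
    latent_t = 1 + (t - 1) // 4                  # -> 13 latent frames
    latent_h, latent_w = h // 2**3, w // 2**3    # -> 60, 90
    assert (latent_t, latent_h, latent_w) == (13, 60, 90)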
.../models/wan/autoencoder_kl_wan.py | 14 ++++++++++ src/maxdiffusion/tests/wan_vae_test.py | 27 ++++++++++++------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 292264c2a..f9876e9c8 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -414,8 +414,20 @@ def __init__( num_layers: int = 1 ): self.dim = dim + resnets = [WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs,dropout=dropout, non_linearity=non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(WanAttentionBlock(dim=dim, rngs=rngs)) + resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs,dropout=dropout, non_linearity=non_linearity)) + self.attentions = attentions + self.resnets = resnets def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + x = self.resnets[0](x) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + x = resnet(x) return x class WanUpBlock(nnx.Module): @@ -519,6 +531,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): # (1, 1, 480, 720, 96) for layer in self.down_blocks: x = layer(x) + + x = self.mid_block(x) breakpoint() return x diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index f91f3932e..5cb799cdc 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -30,6 +30,7 @@ WanUpsample, AutoencoderKLWan, WanEncoder3d, + WanMidBlock, WanResidualBlock, WanRMS_norm, WanResample, @@ -373,6 +374,22 @@ def test_wan_attention(self): output = wan_attention(dummy_input) assert output.shape == input_shape + def test_wan_midblock(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + batch = 1 + t = 1 + dim = 384 + height = 60 + width = 90 + input_shape = (batch, t, height, width, dim) + wan_midblock = WanMidBlock( + dim=dim, rngs=rngs + ) + dummy_input = jnp.ones(input_shape) + output = wan_midblock(dummy_input) + assert output.shape == input_shape + def test_wan_encode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) @@ -392,16 +409,6 @@ def test_wan_encode(self): attn_scales=attn_scales, temperal_downsample=temperal_downsample, ) - # wan_encoder = WanEncoder3d( - # rngs=rngs, - # dim=dim, - # z_dim=z_dim, - # dim_mult=dim_mult, - # num_res_blocks=num_res_blocks, - # attn_scales=attn_scales, - # temperal_downsample=temperal_downsample, - # non_linearity=nonlinearity - # ) batch = 1 channels = 3 t = 49 From a5e1e95b4c66185bedbabf02f2e2cdb7602da0d0 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 28 Apr 2025 20:28:02 +0000 Subject: [PATCH 11/25] finishes vae encoder with matching shapes --- .../models/wan/autoencoder_kl_wan.py | 87 ++++++++++++++----- src/maxdiffusion/tests/wan_vae_test.py | 17 ++-- 2 files changed, 74 insertions(+), 30 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index f9876e9c8..a5678d028 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -289,6 +289,14 @@ def __init__( kernel_size=(1, 3, 3), stride=(1, 2, 2) ) + self.time_conv = WanCausalConv3d( + rngs=rngs, + in_channels = dim, + out_channels = dim, + kernel_size=(3, 1, 1), + stride=(2, 1, 1), + padding= (0, 0, 0) + ) else: self.resample = Identity() @@ -302,6 +310,18 @@ def __call__(self, x: 
jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: h_new, w_new, c_new = x.shape[1:] x = x.reshape(n, d, h_new, w_new, c_new) + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = jnp.copy(x) + feat_idx[0] +=1 + else: + cache_x = jnp.copy(x[:, -1:, :, :, :]) + x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x class WanResidualBlock(nnx.Module): @@ -343,7 +363,6 @@ def __init__( def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): # Apply shortcut connection - #breakpoint() h = self.conv_shortcut(x) x = self.norm1(x) @@ -505,7 +524,8 @@ def __init__( if i != len(dim_mult) - 1: mode = "downsample3d" if temperal_downsample[i] else "downsample2d" self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs)) - + scale /= 2.0 + # middle_blocks self.mid_block = WanMidBlock( dim=out_dim, @@ -516,7 +536,12 @@ def __init__( ) # output blocks - self.norm_out = WanRMS_norm(out_dim, images=False, rngs=rngs) + self.norm_out = WanRMS_norm( + out_dim, + channel_first=False, + images=False, + rngs=rngs + ) self.conv_out = WanCausalConv3d( rngs=rngs, in_channels=out_dim, @@ -526,14 +551,39 @@ def __init__( ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): - # (1, 1, 480, 720, 3) - x = self.conv_in(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T:, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + # cache last frame of the last two chunk + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] +=1 + else: + x = self.conv_in(x) # (1, 1, 480, 720, 96) for layer in self.down_blocks: - x = layer(x) + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) - x = self.mid_block(x) - breakpoint() + x = self.mid_block(x, feat_cache, feat_idx) + + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] +=1 + else: + x = self.conv_out(x) return x class WanDecoder3d(nnx.Module): @@ -626,9 +676,7 @@ def __init__( ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): - breakpoint() x = self.conv_in(x) - breakpoint() return x class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -696,9 +744,7 @@ def _count_conv3d(module): count = 0 node_types = nnx.graph.iter_graph([module]) for path, value in node_types: - #breakpoint() if isinstance(value, WanCausalConv3d): - print("value: ", value) count +=1 return count @@ -711,6 +757,7 @@ def _count_conv3d(module): self._enc_feat_map = [None] * self._enc_conv_num def _encode(self, x: jax.Array): + self.clear_cache() if x.shape[-1] != 3: # reshape channel last for JAX x = jnp.transpose(x, (0, 2, 3, 4, 1)) @@ -721,6 +768,7 @@ def _encode(self, x: jax.Array): t = x.shape[1] iter_ = 1 + (t - 1) // 4 for i in range(iter_): + self._enc_conv_idx = [0] if i == 0: out = self.encoder( x[:, :1, :, :, :], @@ -729,18 +777,17 @@ def _encode(self, x: jax.Array): ) else: out_ = 
self.encoder( - x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx ) out = jnp.concatenate([out, out_], axis=1) - - # enc = self.quant_conv(out) - # mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] - # enc = jnp.concatenate([mu, logvar], dim=1) - # self.clear_cache() + enc = self.quant_conv(out) + mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] + enc = jnp.concatenate([mu, logvar], axis=-1) + self.clear_cache() # return enc - return x + return enc def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: """ Encode video into latent distribution.""" @@ -748,5 +795,5 @@ def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencode posterior = FlaxDiagonalGaussianDistribution(h) if not return_dict: return (posterior, ) - return FlaxAutoencoderKLOutput(latent_dict=posterior) + return FlaxAutoencoderKLOutput(latent_dist=posterior) \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 5cb799cdc..3a69aeb03 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -153,11 +153,11 @@ class WanVaeTest(unittest.TestCase): def setUp(self): WanVaeTest.dummy_data = {} - # def test_clear_cache(self): - # key = jax.random.key(0) - # rngs = nnx.Rngs(key) - # wan_vae = AutoencoderKLWan(rngs=rngs) - # wan_vae.clear_cache() + def test_clear_cache(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + wan_vae = AutoencoderKLWan(rngs=rngs) + wan_vae.clear_cache() def test_wanrms_norm(self): """Test against the Pytorch implementation""" @@ -394,12 +394,11 @@ def test_wan_encode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) dim = 96 - z_dim = 32 + z_dim = 16 dim_mult = [1, 2, 4, 4] num_res_blocks = 2 attn_scales = [] temperal_downsample = [False, True, True] - nonlinearity = "silu" wan_vae = AutoencoderKLWan( rngs=rngs, base_dim=dim, @@ -417,9 +416,7 @@ def test_wan_encode(self): input_shape = (batch, channels, t, height, width) input = jnp.ones(input_shape) output = wan_vae.encode(input) - breakpoint() - - + assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16) if __name__ == "__main__": absltest.main() \ No newline at end of file From efe8528b02836723a183e8e43141c34a5c6ddf0d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 28 Apr 2025 20:55:07 +0000 Subject: [PATCH 12/25] add cache logic to modules. 
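The same rolling-cache bookkeeping is stamped around every causal conv call in
this file. Distilled, the pattern is (a minimal sketch with a generic `conv`;
CACHE_T == 2):

    idx = feat_idx[0]
    cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])  # remember the last two frames
    if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
      # single-frame chunk: borrow the last cached frame so the causal conv
      # still sees two frames of history
      cache_x = jnp.concatenate(
          [jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
    x = conv(x, feat_cache[idx])  # consume the previous cache
    feat_cache[idx] = cache_x     # store the new one
    feat_idx[0] += 1              # advance to the next conv's slot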
--- .../models/wan/autoencoder_kl_wan.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index a5678d028..ade16fe53 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -367,12 +367,33 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.norm1(x) x = self.nonlinearity(x) - x = self.conv1(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) + if cache_x.shape[1] <2 and feat_cache[idx] is not None: + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] +=1 + else: + x = self.conv1(x) x = self.norm2(x) x = self.nonlinearity(x) x = self.dropout(x) - x = self.conv2(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) + if cache_x.shape[1] <2 and feat_cache[idx] is not None: + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] +=1 + else: + x = self.conv2(x) return x + h @@ -442,7 +463,7 @@ def __init__( self.resnets = resnets def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): - x = self.resnets[0](x) + x = self.resnets[0](x, feat_cache, feat_idx) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: x = attn(x) From cf68754ef7efeb4263a887baa8c625e8dd370027 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 29 Apr 2025 18:32:23 +0000 Subject: [PATCH 13/25] adds decoder and checks matching resolutions. --- .../models/wan/autoencoder_kl_wan.py | 186 +++++++++++++----- src/maxdiffusion/tests/wan_vae_test.py | 33 ++++ 2 files changed, 174 insertions(+), 45 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index ade16fe53..060bda017 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -22,7 +22,11 @@ from ...configuration_utils import ConfigMixin, flax_register_to_config from ..modeling_flax_utils import FlaxModelMixin from ... 
import common_types -from ..vae_flax import FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution +from ..vae_flax import ( + FlaxAutoencoderKLOutput, + FlaxDiagonalGaussianDistribution, + FlaxDecoderOutput +) BlockSizes = common_types.BlockSizes @@ -82,7 +86,7 @@ def __init__( (0, 0) # Channel dimension - no padding ) - # Store the amount of padding needed *before* the depth dimension for caching logoic + # Store the amount of padding needed *before* the depth dimension for caching logic self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] self.conv = nnx.Conv( @@ -103,7 +107,6 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Arr # Ensure cache has same spatial/channel dims, potentially different depth assert cache_x.shape[0] == x.shape[0] and \ cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" - cache_len = cache_x.shape[1] x = jnp.concatenate([cache_x, x], axis=1) # Concat along depth (D) @@ -166,24 +169,13 @@ def __init__(self, scale_factor: Tuple[float, float], method: str = 'nearest'): def __call__(self, x: jax.Array) -> jax.Array: input_dtype = x.dtype in_shape = x.shape - is_3d = len(in_shape) == 5 - n, d, h, w, c = in_shape if is_3d else(in_shape[0], 1, in_shape[1], in_shape[2], in_shape[3]) - + assert len(in_shape) == 4, "This module only takes tensors with shape of 4." + n, h, w, c = in_shape target_h = int(h * self.scale_factor[0]) target_w = int(w * self.scale_factor[1]) - - # jax.image.resize expects (..., H, W, C) - if is_3d: - x_reshaped = x.reshape(n * d, h, w, c) - out_reshaped = jax.image.resize(x_reshaped.astype(jnp.float32), - (n * d, target_h, target_w, c), - method=self.method) - out = out_reshaped.reshape(n, d, target_h, target_w, c) - else: # Asumming (N, H, W, C) - out = jax.image.resize(x.astype(jnp.float32), - (n, target_h, target_w, c), - method=self.method) - + out = jax.image.resize(x.astype(jnp.float32), + (n, target_h, target_w, c), + method=self.method) return out.astype(input_dtype) class Identity(nnx.Module): @@ -256,7 +248,7 @@ def __init__( ) elif mode == "upsample3d": self.resample = nnx.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), + WanUpsample(scale_factor=(2.0, 2.0, 2.0), method="nearest"), nnx.Conv( dim, dim // 2, @@ -305,6 +297,29 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: n, d, h, w, c = x.shape assert c == self.dim + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], dim=1) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x = x.reshape(n, 2, d, h, w, c) + x = jnp.stack([x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]], axis=2) + x = x.reshape(n, d*2, h, w, c) + d = x.shape[1] x = x.reshape(n*d,h,w,c) x = self.resample(x) h_new, w_new, c_new = x.shape[1:] @@ -371,7 +386,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = 
feat_idx[0] cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) - if cache_x.shape[1] <2 and feat_cache[idx] is not None: + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv1(x, feat_cache[idx]) @@ -387,7 +402,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = feat_idx[0] cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) - if cache_x.shape[1] <2 and feat_cache[idx] is not None: + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv2(x, feat_cache[idx]) feat_cache[idx] = cache_x @@ -458,7 +473,7 @@ def __init__( attentions = [] for _ in range(num_layers): attentions.append(WanAttentionBlock(dim=dim, rngs=rngs)) - resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs,dropout=dropout, non_linearity=non_linearity)) + resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity)) self.attentions = attentions self.resnets = resnets @@ -467,7 +482,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: x = attn(x) - x = resnet(x) + x = resnet(x, feat_cache, feat_idx) return x class WanUpBlock(nnx.Module): @@ -482,19 +497,31 @@ def __init__( non_linearity: str = "silu" ): # Create layers list - self.resnets = [] + resnets = [] # Add residual blocks and attention if needed current_dim = in_dim for _ in range(num_res_blocks + 1): - self.resnets.append(WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs)) + resnets.append(WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs)) current_dim = out_dim + self.resnets = resnets # Add upsampling layer if needed. 
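     # NOTE: wrapping the single WanResample in a list below mirrors the upstream
     # diffusers ModuleList container for `upsamplers`; the assumption is that the
     # matching nesting keeps parameter key paths aligned when porting checkpoints.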
self.upsamplers = None if upsample_mode is not None: - self.upsamplers = WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs) + self.upsamplers = [WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs)] def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) return x class WanEncoder3d(nnx.Module): @@ -655,7 +682,13 @@ def __init__( ) # middle_blocks - self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1) + self.mid_block = WanMidBlock( + dim=dims[0], + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + num_layers=1 + ) # upsample blocks self.up_blocks = [] @@ -668,7 +701,6 @@ def __init__( upsample_mode = None if i != len(dim_mult) - 1: upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" - # Crete and add the upsampling block up_block = WanUpBlock( in_dim=in_dim, @@ -686,8 +718,7 @@ def __init__( scale *=2.0 # output blocks - self.norm_out = nnx.RMSNorm(num_features=out_dim, ) - self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs) + self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False) self.conv_out = WanCausalConv3d( rngs=rngs, in_channels=out_dim, @@ -697,7 +728,39 @@ def __init__( ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): - x = self.conv_in(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T: , :, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + # cache last frame of the last two chunk + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + # cache last frame of the last two chunk + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) return x class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -723,6 +786,8 @@ def __init__( self.z_dim = z_dim self.temperal_downsample = temperal_downsample self.temporal_upsample = temperal_downsample[::-1] + self.latents_mean = latents_mean + self.latents_std = latents_std self.encoder = WanEncoder3d( rngs=rngs, @@ -747,16 +812,16 @@ def __init__( kernel_size=1, ) - # self.decoder = WanDecoder3d( - # rngs=rngs, - # dim=base_dim, - # z_dim=z_dim, - # dim_mult=dim_mult, - # num_res_blocks=num_res_blocks, - # attn_scales=attn_scales, - # temperal_upsample=self.temporal_upsample, - # dropout=dropout - # ) + self.decoder = WanDecoder3d( + rngs=rngs, + dim=base_dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_upsample=self.temporal_upsample, + dropout=dropout + ) self.clear_cache() def 
clear_cache(self):
@@ -769,9 +834,9 @@ def _count_conv3d(module):
           count +=1
     return count
 
-    # self._conv_num = _count_conv3d(self.decoder)
-    # self._conv_idx = [0]
-    # self._feat_map = [None] * self._conv_num
+    self._conv_num = _count_conv3d(self.decoder)
+    self._conv_idx = [0]
+    self._feat_map = [None] * self._conv_num
     # cache encode
     self._enc_conv_num = _count_conv3d(self.encoder)
     self._enc_conv_idx = [0]
@@ -817,4 +882,35 @@ def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencode
     if not return_dict:
       return (posterior, )
     return FlaxAutoencoderKLOutput(latent_dist=posterior)
+
+  def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]:
+    self.clear_cache()
+    iter_ = z.shape[1]
+    x = self.post_quant_conv(z)
+    for i in range(iter_):
+      self._conv_idx = [0]
+      if i == 0:
+        out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+      else:
+        out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+
+        out = jnp.concatenate([out, out_], axis=1)
+
+    out = jnp.clip(out, a_min=-1.0, a_max=1.0)
+    self.clear_cache()
+    if not return_dict:
+      return (out, )
+
+    return FlaxDecoderOutput(sample=out)
+
+  def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]:
+    if z.shape[-1] != self.z_dim:
+      # reshape channel last for JAX
+      z = jnp.transpose(z, (0, 2, 3, 4, 1))
+    assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}), got {z.shape}"
+    decoded = self._decode(z).sample
+    if not return_dict:
+      return (decoded,)
+    return FlaxDecoderOutput(sample=decoded)
+
\ No newline at end of file
diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py
index 3a69aeb03..a2dd50bff 100644
--- a/src/maxdiffusion/tests/wan_vae_test.py
+++ b/src/maxdiffusion/tests/wan_vae_test.py
@@ -390,6 +390,39 @@ def test_wan_midblock(self):
     output = wan_midblock(dummy_input)
     assert output.shape == input_shape
 
+  def test_wan_decode(self):
+    key = jax.random.key(0)
+    rngs = nnx.Rngs(key)
+    dim = 96
+    z_dim = 16
+    dim_mult = [1, 2, 4, 4]
+    num_res_blocks = 2
+    attn_scales = []
+    temperal_downsample = [False, True, True]
+    wan_vae = AutoencoderKLWan(
+        rngs=rngs,
+        base_dim=dim,
+        z_dim=z_dim,
+        dim_mult=dim_mult,
+        num_res_blocks=num_res_blocks,
+        attn_scales=attn_scales,
+        temperal_downsample=temperal_downsample,
+    )
+
+    batch = 1
+    t = 13
+    channels = 16
+    height = 60
+    width = 90
+    input_shape = (batch, t, height, width, channels)
+    input = jnp.ones(input_shape)
+
+    latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim)
+    latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim)
+    input = input / latents_std + latents_mean
+    dummy_output = wan_vae.decode(input)
+    assert dummy_output.sample.shape == (batch, 49, 480, 720, 3)
+
   def test_wan_encode(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     dim = 96

From cd16f28b7be3db0ee0b4af531046db9601277997 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Tue, 29 Apr 2025 18:36:10 +0000
Subject: [PATCH 14/25] run linter

---
 src/maxdiffusion/generate_wan.py              | 235 +++---
 src/maxdiffusion/image_processor.py           |  71 +-
 src/maxdiffusion/models/attention_flax.py     |  17 +-
 src/maxdiffusion/models/embeddings_flax.py    |   1 +
 .../models/wan/autoencoder_kl_wan.py          | 686 ++++++++----------
 .../wan/transformers/transformer_flux_wan.py  | 169 ++---
 .../transformers/transformer_flux_wan_nnx.py  |  71 +-
.../pipelines/wan/pipeline_wan.py | 42 +- ...heduling_flow_match_euler_discrete_flax.py | 39 +- src/maxdiffusion/tests/wan_vae_test.py | 351 +++++---- src/maxdiffusion/video_processor.py | 156 ++-- 11 files changed, 877 insertions(+), 961 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index e19d783de..3a79621d3 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import html from typing import Callable, List, Union, Sequence, Optional import time @@ -37,21 +38,23 @@ setup_initial_state, ) + def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text def prompt_clean(text): - text = whitespace_clean(basic_clean(text)) - return text + text = whitespace_clean(basic_clean(text)) + return text + def _get_t5_prompt_embeds( tokenizer: AutoTokenizer, @@ -63,35 +66,36 @@ def _get_t5_prompt_embeds( dtype: Optional[torch.dtype] = None, ): - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(u) for u in prompt] - batch_size = len(prompt) + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask - seq_lens = mask.gt(0).sum(dim=1).long() - - prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + 
return prompt_embeds - return prompt_embeds def encode_prompt( tokenizer: AutoTokenizer, @@ -106,77 +110,77 @@ def encode_prompt( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - Whether to use classifier free guidance or not. - num_videos_per_prompt (`int`, *optional*, defaults to 1): - Number of videos that should be generated per prompt. torch device to place the resulting embeddings on - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - device: (`torch.device`, *optional*): - torch device - dtype: (`torch.dtype`, *optional*): - torch dtype - """ - - prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - prompt_embeds = _get_t5_prompt_embeds( - tokenizer=tokenizer, - text_encoder=text_encoder, - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - negative_prompt_embeds = _get_t5_prompt_embeds( - tokenizer=tokenizer, - text_encoder=text_encoder, - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - - return prompt_embeds, negative_prompt_embeds + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. 
+ num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = _get_t5_prompt_embeds( + tokenizer=tokenizer, + text_encoder=text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds = _get_t5_prompt_embeds( + tokenizer=tokenizer, + text_encoder=text_encoder, + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + def run(config): max_logging.log("Wan 2.1 inference script") @@ -188,17 +192,15 @@ def run(config): global_batch_size = config.per_device_batch_size * jax.local_device_count() tokenizer = AutoTokenizer.from_pretrained( - config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype + config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype ) text_encoder = UMT5EncoderModel.from_pretrained( - config.pretrained_model_name_or_path, subfolder="text_encoder", + config.pretrained_model_name_or_path, + subfolder="text_encoder", ) s0 = time.perf_counter() prompt_embeds, negative_prompt_embeds = encode_prompt( - tokenizer=tokenizer, - text_encoder=text_encoder, - prompt=config.prompt, - negative_prompt=config.negative_prompt + tokenizer=tokenizer, text_encoder=text_encoder, prompt=config.prompt, negative_prompt=config.negative_prompt ) max_logging.log(f"text encoding time: {(time.perf_counter() - s0)}") @@ -209,20 +211,15 @@ def run(config): # ) # breakpoint() - pipeline, params = WanPipeline.from_pretrained( - config.pretrained_model_name_or_path, - vae=None, - transformer=None - ) - - #wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) - + pipeline, params = WanPipeline.from_pretrained(config.pretrained_model_name_or_path, vae=None, transformer=None) + # wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) run(pyconfig.config) + if __name__ == "__main__": app.run(main) diff --git a/src/maxdiffusion/image_processor.py b/src/maxdiffusion/image_processor.py index 2f99d0f7f..76fa7635e 100644 --- a/src/maxdiffusion/image_processor.py +++ b/src/maxdiffusion/image_processor.py @@ -35,51 +35,52 @@ List[torch.FloatTensor], ] + def is_valid_image(image) -> bool: - r""" - Checks if the input is a valid image. + r""" + Checks if the input is a valid image. - A valid image can be: - - A `PIL.Image.Image`. - - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). + A valid image can be: + - A `PIL.Image.Image`. + - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). - Args: - image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): - The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. - Returns: - `bool`: - `True` if the input is a valid image, `False` otherwise. - """ - return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) + Returns: + `bool`: + `True` if the input is a valid image, `False` otherwise. + """ + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) def is_valid_image_imagelist(images): - r""" - Checks if the input is a valid image or list of images. + r""" + Checks if the input is a valid image or list of images. - The input can be one of the following formats: - - A 4D tensor or numpy array (batch of images). 
- - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or - `torch.Tensor`. - - A list of valid images. + The input can be one of the following formats: + - A 4D tensor or numpy array (batch of images). + - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or + `torch.Tensor`. + - A list of valid images. - Args: - images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): - The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid - images. + Args: + images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): + The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid + images. - Returns: - `bool`: - `True` if the input is valid, `False` otherwise. - """ - if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: - return True - elif is_valid_image(images): - return True - elif isinstance(images, list): - return all(is_valid_image(image) for image in images) - return False + Returns: + `bool`: + `True` if the input is valid, `False` otherwise. + """ + if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: + return True + elif is_valid_image(images): + return True + elif isinstance(images, list): + return all(is_valid_image(image) for image in images) + return False class VaeImageProcessor(ConfigMixin): diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index fc838c9db..2f8946056 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -405,6 +405,7 @@ def chunk_scanner(chunk_idx, _): return jnp.concatenate(res, axis=-3) # fuse the chunked result back + def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) @@ -414,6 +415,7 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) + class FlaxWanAttention(nn.Module): query_dim: int heads: int = 8 @@ -433,7 +435,7 @@ class FlaxWanAttention(nn.Module): out_axis_names: AxisNames = (BATCH, LENGTH, EMBED) precision: jax.lax.Precision = None qkv_bias: bool = False - + def setup(self): if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") @@ -509,13 +511,13 @@ def setup(self): precision=self.precision, ) self.dropout_layer = nn.Dropout(rate=self.dropout) - + def call( - self, - hidden_states: Array, - encoder_hidden_states: Optional[Array], - rotary_emb: Optional[Array], - deterministic: bool = True + self, + hidden_states: Array, + encoder_hidden_states: Optional[Array], + rotary_emb: Optional[Array], + deterministic: bool = True, ): encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states @@ -539,6 +541,7 @@ def call( hidden_states = nn.with_logical_constraint(hidden_states, (BATCH, LENGTH, HEAD)) return self.dropout_layer(hidden_states, deterministic=deterministic) + class FlaxFluxAttention(nn.Module): query_dim: int heads: int = 8 diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index 3377ef6cc..cc961e131 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ 
b/src/maxdiffusion/models/embeddings_flax.py @@ -56,6 +56,7 @@ def get_sinusoidal_embeddings( signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) return signal + class FlaxTimestepEmbedding(nn.Module): r""" Time step Embedding Module. Learns embeddings for input time steps. diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 060bda017..2cf43f69c 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -22,23 +22,14 @@ from ...configuration_utils import ConfigMixin, flax_register_to_config from ..modeling_flax_utils import FlaxModelMixin from ... import common_types -from ..vae_flax import ( - FlaxAutoencoderKLOutput, - FlaxDiagonalGaussianDistribution, - FlaxDecoderOutput -) +from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) BlockSizes = common_types.BlockSizes CACHE_T = 2 -_ACTIVATIONS = { - "swish": jax.nn.silu, - "silu": jax.nn.silu, - "relu": jax.nn.relu, - "gelu": jax.nn.gelu, - "mish": jax.nn.mish -} +_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "mish": jax.nn.mish} + def get_activation(name: str): func = _ACTIVATIONS.get(name) @@ -46,111 +37,115 @@ def get_activation(name: str): raise ValueError(f"Unknown activation function: {name}") return func + # Helper to ensure kernel_size, stride, padding are tuples of 3 integers def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: - """Canonicalizes a value to a tuple of integers.""" - if isinstance(x, int): - return (x,) * rank - elif isinstance(x, Sequence) and len(x) == rank: - return tuple(x) - else: - raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank} integers. Got {x}") + """Canonicalizes a value to a tuple of integers.""" + if isinstance(x, int): + return (x,) * rank + elif isinstance(x, Sequence) and len(x) == rank: + return tuple(x) + else: + raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank} integers. 
Got {x}") + class WanCausalConv3d(nnx.Module): + def __init__( - self, - rngs: nnx.Rngs, # rngs are required for initializing parameters, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int, int]], - stride: Union[int, Tuple[int, int, int]] = 1, - padding: Union[int, Tuple[int, int, int]] = 0, - use_bias: bool = True, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", + self, + rngs: nnx.Rngs, # rngs are required for initializing parameters, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + use_bias: bool = True, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", ): - self.kernel_size = _canonicalize_tuple(kernel_size, 3, 'kernel_size') - self.stride = _canonicalize_tuple(stride, 3, 'stride') - padding_tuple = _canonicalize_tuple(padding, 3, 'padding') # (D, H, W) padding amounts + self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") + self.stride = _canonicalize_tuple(stride, 3, "stride") + padding_tuple = _canonicalize_tuple(padding, 3, "padding") # (D, H, W) padding amounts self._causal_padding = ( - (0, 0), # Batch dimension - no padding - (2 * padding_tuple[0], 0), # Depth dimension - causal padding (pad only before) - (padding_tuple[1], padding_tuple[1]), # Height dimension - symmetric padding - (padding_tuple[2], padding_tuple[2]), # Width dimension - symmetric padding - (0, 0) # Channel dimension - no padding + (0, 0), # Batch dimension - no padding + (2 * padding_tuple[0], 0), # Depth dimension - causal padding (pad only before) + (padding_tuple[1], padding_tuple[1]), # Height dimension - symmetric padding + (padding_tuple[2], padding_tuple[2]), # Width dimension - symmetric padding + (0, 0), # Channel dimension - no padding ) # Store the amount of padding needed *before* the depth dimension for caching logic - self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] + self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] self.conv = nnx.Conv( - in_features=in_channels, - out_features=out_channels, - kernel_size=self.kernel_size, - strides=self.stride, - use_bias=use_bias, - padding='VALID', # Handle padding manually - rngs=rngs + in_features=in_channels, + out_features=out_channels, + kernel_size=self.kernel_size, + strides=self.stride, + use_bias=use_bias, + padding="VALID", # Handle padding manually + rngs=rngs, ) - + def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Array: - current_padding = list(self._causal_padding) # Mutable copy + current_padding = list(self._causal_padding) # Mutable copy padding_needed = self._depth_padding_before if cache_x is not None and padding_needed > 0: # Ensure cache has same spatial/channel dims, potentially different depth - assert cache_x.shape[0] == x.shape[0] and \ - cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" + assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" cache_len = 
cache_x.shape[1]
-      x = jnp.concatenate([cache_x, x], axis=1) # Concat along depth (D)
+      x = jnp.concatenate([cache_x, x], axis=1)  # Concat along depth (D)
       padding_needed -= cache_len

       if padding_needed < 0:
         # Cache longer than needed padding, trim from start
         x = x[:, -padding_needed:, ...]
-        current_padding[1] = (0, 0) # No explicit padding needed now
+        current_padding[1] = (0, 0)  # No explicit padding needed now
       else:
         # Update depth padding needed
         current_padding[1] = (padding_needed, 0)
-    
+
     # Apply padding if any dimension requires it
     padding_to_apply = tuple(current_padding)
     if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads):
-      x_padded = jnp.pad(x, padding_to_apply, mode='constant', constant_values=0.0)
+      x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
     else:
       x_padded = x
-    
+
     out = self.conv(x_padded)
     return out

+
 class WanRMS_norm(nnx.Module):
+
   def __init__(
-      self,
-      dim: int,
-      rngs: nnx.Rngs,
-      channel_first: bool = True,
-      images: bool = True,
-      eps: float = 1e-6,
-      use_bias: bool = False,
+      self,
+      dim: int,
+      rngs: nnx.Rngs,
+      channel_first: bool = True,
+      images: bool = True,
+      eps: float = 1e-6,
+      use_bias: bool = False,
   ):
     broadcastable_dims = (1, 1, 1) if not images else (1, 1)
-    shape = (dim, *broadcastable_dims) if channel_first else (dim, )
+    shape = (dim, *broadcastable_dims) if channel_first else (dim,)

     self.eps = eps
     self.channel_first = channel_first
-    self.scale = dim ** 0.5
+    self.scale = dim**0.5
     # Initialize gamma as parameter
     self.gamma = nnx.Param(jnp.ones(shape))
     if use_bias:
       self.bias = nnx.Param(jnp.zeros(shape))
     else:
       self.bias = 0
-  
+
   def __call__(self, x: jax.Array) -> jax.Array:
     normalized = jnp.linalg.norm(x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True)
     normalized = x / jnp.maximum(normalized, self.eps)
@@ -159,13 +154,15 @@ def __call__(self, x: jax.Array) -> jax.Array:
       return normalized + self.bias.value
     return normalized

+
 class WanUpsample(nnx.Module):
-  def __init__(self, scale_factor: Tuple[float, float], method: str = 'nearest'):
+
+  def __init__(self, scale_factor: Tuple[float, float], method: str = "nearest"):
     # scale_factor for (H, W)
-    # JAX resize works on spatial dims, H, W assumming (N, D, H, W, C) or (N, H, W, C)
+    # JAX resize works on spatial dims, H, W assuming (N, D, H, W, C) or (N, H, W, C)
     self.scale_factor = scale_factor
     self.method = method
-  
+
   def __call__(self, x: jax.Array) -> jax.Array:
     input_dtype = x.dtype
     in_shape = x.shape
@@ -173,62 +170,58 @@ def __call__(self, x: jax.Array) -> jax.Array:
     n, h, w, c = in_shape
     target_h = int(h * self.scale_factor[0])
     target_w = int(w * self.scale_factor[1])
-    out = jax.image.resize(x.astype(jnp.float32),
-                           (n, target_h, target_w, c),
-                           method=self.method)
+    out = jax.image.resize(x.astype(jnp.float32), (n, target_h, target_w, c), method=self.method)
     return out.astype(input_dtype)

+
 class Identity(nnx.Module):
+
   def __call__(self, x):
     return x

+
 class ZeroPaddedConv2D(nnx.Module):
   """
-  Module for adding padding before conv. Currently it does not add any padding.
+  Stand-in for the Torch ZeroPad2d + Conv2d pair. Note: it does not yet add any explicit zero padding.
""" + def __init__( - self, - dim: int, - rngs: nnx.Rngs, - kernel_size: Union[int, Tuple[int, int, int]], - stride: Union[int, Tuple[int, int, int]] = 1, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", + self, + dim: int, + rngs: nnx.Rngs, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", ): - kernel_size = _canonicalize_tuple(kernel_size, 3, 'kernel_size') - stride = _canonicalize_tuple(stride, 3, 'stride') - self.conv = nnx.Conv( - dim, - dim, - kernel_size=kernel_size, - strides=stride, - use_bias=True, - rngs=rngs - ) - + kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") + stride = _canonicalize_tuple(stride, 3, "stride") + self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs) + def __call__(self, x): return self.conv(x) class WanResample(nnx.Module): + def __init__( - self, - dim: int, - mode: str, - rngs: nnx.Rngs, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", + self, + dim: int, + mode: str, + rngs: nnx.Rngs, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", ): self.dim = dim self.mode = mode @@ -236,58 +229,43 @@ def __init__( if mode == "upsample2d": self.resample = nnx.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), - nnx.Conv( - dim, - dim // 2, - kernel_size=(1, 3, 3), - padding='SAME', - use_bias=True, - rngs=rngs, - ) + WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), + nnx.Conv( + dim, + dim // 2, + kernel_size=(1, 3, 3), + padding="SAME", + use_bias=True, + rngs=rngs, + ), ) elif mode == "upsample3d": self.resample = nnx.Sequential( - WanUpsample(scale_factor=(2.0, 2.0, 2.0), method="nearest"), - nnx.Conv( - dim, - dim // 2, - kernel_size=(1, 3, 3), - padding='SAME', - use_bias=True, - rngs=rngs, - ) + WanUpsample(scale_factor=(2.0, 2.0, 2.0), method="nearest"), + nnx.Conv( + dim, + dim // 2, + kernel_size=(1, 3, 3), + padding="SAME", + use_bias=True, + rngs=rngs, + ), ) self.time_conv = WanCausalConv3d( - rngs=rngs, - in_channels=dim, - out_channels=dim * 2, - kernel_size=(3, 1, 1), - padding=(1, 0, 0), + rngs=rngs, + in_channels=dim, + out_channels=dim * 2, + kernel_size=(3, 1, 1), + padding=(1, 0, 0), ) elif mode == "downsample2d": # TODO - do I need to transpose? - self.resample = ZeroPaddedConv2D( - dim=dim, - rngs=rngs, - kernel_size=(1, 3, 3), - stride=(1, 2, 2) - ) + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) elif mode == "downsample3d": # TODO - do I need to transpose? 
- self.resample = ZeroPaddedConv2D( - dim=dim, - rngs=rngs, - kernel_size=(1, 3, 3), - stride=(1, 2, 2) - ) + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) self.time_conv = WanCausalConv3d( - rngs=rngs, - in_channels = dim, - out_channels = dim, - kernel_size=(3, 1, 1), - stride=(2, 1, 1), - padding= (0, 0, 0) + rngs=rngs, in_channels=dim, out_channels=dim, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) ) else: self.resample = Identity() @@ -318,9 +296,9 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: feat_idx[0] += 1 x = x.reshape(n, 2, d, h, w, c) x = jnp.stack([x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]], axis=2) - x = x.reshape(n, d*2, h, w, c) + x = x.reshape(n, d * 2, h, w, c) d = x.shape[1] - x = x.reshape(n*d,h,w,c) + x = x.reshape(n * d, h, w, c) x = self.resample(x) h_new, w_new, c_new = x.shape[1:] x = x.reshape(n, d, h_new, w_new, c_new) @@ -330,7 +308,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: idx = feat_idx[0] if feat_cache[idx] is None: feat_cache[idx] = jnp.copy(x) - feat_idx[0] +=1 + feat_idx[0] += 1 else: cache_x = jnp.copy(x[:, -1:, :, :, :]) x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1)) @@ -338,8 +316,10 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: feat_idx[0] += 1 return x - + + class WanResidualBlock(nnx.Module): + def __init__( self, in_dim: int, @@ -352,29 +332,15 @@ def __init__( # layers self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False) - self.conv1 = WanCausalConv3d( - rngs=rngs, - in_channels=in_dim, - out_channels=out_dim, - kernel_size=3, - padding=1 - ) + self.conv1 = WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=3, padding=1) self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False) self.dropout = nnx.Dropout(dropout, rngs=rngs) - self.conv2 = WanCausalConv3d( - rngs=rngs, - in_channels=out_dim, - out_channels=out_dim, - kernel_size=3, - padding=1 + self.conv2 = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=out_dim, kernel_size=3, padding=1) + self.conv_shortcut = ( + WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=1) + if in_dim != out_dim + else Identity() ) - self.conv_shortcut = WanCausalConv3d( - rngs=rngs, - in_channels=in_dim, - out_channels=out_dim, - kernel_size=1 - ) if in_dim != out_dim else Identity() - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): # Apply shortcut connection @@ -391,7 +357,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x - feat_idx[0] +=1 + feat_idx[0] += 1 else: x = self.conv1(x) @@ -406,50 +372,38 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv2(x, feat_cache[idx]) feat_cache[idx] = cache_x - feat_idx[0] +=1 + feat_idx[0] += 1 else: x = self.conv2(x) return x + h + class WanAttentionBlock(nnx.Module): - def __init__( - self, - dim: int, - rngs: nnx.Rngs - ): + + def __init__(self, dim: int, rngs: nnx.Rngs): self.dim = dim self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False) - self.to_qkv = nnx.Conv( - in_features=dim, - out_features=dim * 3, - kernel_size=1, - rngs=rngs - ) - self.proj = nnx.Conv( - in_features=dim, - out_features=dim, - 
kernel_size=1, - rngs=rngs - ) - + self.to_qkv = nnx.Conv(in_features=dim, out_features=dim * 3, kernel_size=1, rngs=rngs) + self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=1, rngs=rngs) + def __call__(self, x: jax.Array): batch_size, time, height, width, channels = x.shape identity = x - + x = x.reshape(batch_size * time, height, width, channels) x = self.norm(x) - qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3) + qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3) - qkv = qkv.reshape(batch_size*time, 1, channels * 3, -1) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) qkv = jnp.transpose(qkv, (0, 1, 3, 2)) q, k, v = jnp.split(qkv, 3, axis=-1) x = jax.nn.dot_product_attention(q, k, v) x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels) - #output projection + # output projection x = self.proj(x) # Reshape back @@ -458,25 +412,18 @@ def __call__(self, x: jax.Array): return x + identity - class WanMidBlock(nnx.Module): - def __init__( - self, - dim: int, - rngs: nnx.Rngs, - dropout: float = 0.0, - non_linearity: str = "silu", - num_layers: int = 1 - ): + + def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): self.dim = dim - resnets = [WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs,dropout=dropout, non_linearity=non_linearity)] + resnets = [WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity)] attentions = [] for _ in range(num_layers): attentions.append(WanAttentionBlock(dim=dim, rngs=rngs)) resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity)) self.attentions = attentions self.resnets = resnets - + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.resnets[0](x, feat_cache, feat_idx) for attn, resnet in zip(self.attentions, self.resnets[1:]): @@ -485,38 +432,42 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = resnet(x, feat_cache, feat_idx) return x + class WanUpBlock(nnx.Module): + def __init__( - self, - in_dim: int, - out_dim: int, - num_res_blocks: int, - rngs: nnx.Rngs, - dropout: float = 0.0, - upsample_mode: Optional[str] = None, - non_linearity: str = "silu" + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + rngs: nnx.Rngs, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", ): # Create layers list resnets = [] # Add residual blocks and attention if needed current_dim = in_dim for _ in range(num_res_blocks + 1): - resnets.append(WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs)) + resnets.append( + WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs) + ) current_dim = out_dim self.resnets = resnets - + # Add upsampling layer if needed. 
self.upsamplers = None if upsample_mode is not None: self.upsamplers = [WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs)] - + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): for resnet in self.resnets: if feat_cache is not None: x = resnet(x, feat_cache, feat_idx) else: x = resnet(x) - + if self.upsamplers is not None: if feat_cache is not None: x = self.upsamplers[0](x, feat_cache, feat_idx) @@ -524,18 +475,20 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.upsamplers[0](x) return x + class WanEncoder3d(nnx.Module): + def __init__( - self, - rngs: nnx.Rngs, - dim: int =128, - z_dim: int = 4, - dim_mult = [1, 2, 4, 4], - num_res_blocks = 2, - attn_scales = [], - temperal_downsample = [True, True, False], - dropout = 0.0, - non_linearity: str = 'silu', + self, + rngs: nnx.Rngs, + dim: int = 128, + z_dim: int = 4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", ): self.dim = dim self.z_dim = z_dim @@ -551,11 +504,11 @@ def __init__( # init block self.conv_in = WanCausalConv3d( - rngs=rngs, - in_channels=3, - out_channels=dims[0], - kernel_size=3, - padding=1, + rngs=rngs, + in_channels=3, + out_channels=dims[0], + kernel_size=3, + padding=1, ) # downsample blocks @@ -567,37 +520,26 @@ def __init__( if scale in attn_scales: self.down_blocks.append(WanAttentionBlock(dim=out_dim, rngs=rngs)) in_dim = out_dim - + # downsample block if i != len(dim_mult) - 1: mode = "downsample3d" if temperal_downsample[i] else "downsample2d" self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs)) scale /= 2.0 - + # middle_blocks self.mid_block = WanMidBlock( - dim=out_dim, - rngs=rngs, - dropout=dropout, - non_linearity=non_linearity, - num_layers=1, + dim=out_dim, + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + num_layers=1, ) # output blocks - self.norm_out = WanRMS_norm( - out_dim, - channel_first=False, - images=False, - rngs=rngs - ) - self.conv_out = WanCausalConv3d( - rngs=rngs, - in_channels=out_dim, - out_channels=z_dim, - kernel_size=3, - padding=1 - ) - + self.norm_out = WanRMS_norm(out_dim, channel_first=False, images=False, rngs=rngs) + self.conv_out = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=z_dim, kernel_size=3, padding=1) + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = feat_idx[0] @@ -607,7 +549,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv_in(x, feat_cache[idx]) feat_cache[idx] = cache_x - feat_idx[0] +=1 + feat_idx[0] += 1 else: x = self.conv_in(x) # (1, 1, 480, 720, 96) @@ -616,7 +558,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = layer(x, feat_cache, feat_idx) else: x = layer(x) - + x = self.mid_block(x, feat_cache, feat_idx) x = self.norm_out(x) @@ -629,11 +571,12 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv_out(x, feat_cache[idx]) feat_cache[idx] = cache_x - feat_idx[0] +=1 + feat_idx[0] += 1 else: x = self.conv_out(x) return x + class WanDecoder3d(nnx.Module): r""" A 3D decoder module. @@ -647,17 +590,18 @@ class WanDecoder3d(nnx.Module): dropout (float): Dropout rate for the dropout layers. 
non_linearity (str): Type of non-linearity to use.
  """

+
  def __init__(
-      self,
-      rngs: nnx.Rngs,
-      dim: int = 128,
-      z_dim: int = 128,
-      dim_mult: List[int] = [1, 2, 4, 4],
-      num_res_blocks: int = 2,
-      attn_scales = List[float],
-      temperal_upsample=[False, True, True],
-      dropout=0.0,
-      non_linearity: str = "silu",
+      self,
+      rngs: nnx.Rngs,
+      dim: int = 128,
+      z_dim: int = 128,
+      dim_mult: List[int] = [1, 2, 4, 4],
+      num_res_blocks: int = 2,
+      attn_scales: List[float] = [],
+      temperal_upsample=[False, True, True],
+      dropout=0.0,
+      non_linearity: str = "silu",
  ):
    self.dim = dim
    self.z_dim = z_dim
@@ -673,22 +617,10 @@ def __init__(
    scale = 1.0 / 2 ** (len(dim_mult) - 2)

    # init block
-    self.conv_in = WanCausalConv3d(
-        rngs=rngs,
-        in_channels=z_dim,
-        out_channels=dims[0],
-        kernel_size=3,
-        padding=1
-    )
+    self.conv_in = WanCausalConv3d(rngs=rngs, in_channels=z_dim, out_channels=dims[0], kernel_size=3, padding=1)

    # middle_blocks
-    self.mid_block = WanMidBlock(
-        dim=dims[0],
-        rngs=rngs,
-        dropout=dropout,
-        non_linearity=non_linearity,
-        num_layers=1
-    )
+    self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1)

    # upsample blocks
    self.up_blocks = []
@@ -696,41 +628,35 @@ def __init__(
      # residual (+attention) blocks
      if i > 0:
        in_dim = in_dim // 2
-      
+
      # Determine if we need upsampling
      upsample_mode = None
      if i != len(dim_mult) - 1:
        upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
-      # Crete and add the upsampling block
+      # Create and add the upsampling block
      up_block = WanUpBlock(
-          in_dim=in_dim,
-          out_dim=out_dim,
-          num_res_blocks=num_res_blocks,
-          dropout=dropout,
-          upsample_mode=upsample_mode,
-          non_linearity=non_linearity,
-          rngs=rngs
+          in_dim=in_dim,
+          out_dim=out_dim,
+          num_res_blocks=num_res_blocks,
+          dropout=dropout,
+          upsample_mode=upsample_mode,
+          non_linearity=non_linearity,
+          rngs=rngs,
      )
      self.up_blocks.append(up_block)
      # Update scale for next iteration
      if upsample_mode is not None:
-        scale *=2.0
-      
+        scale *= 2.0
+
    # output blocks
    self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False)
-    self.conv_out = WanCausalConv3d(
-        rngs=rngs,
-        in_channels=out_dim,
-        out_channels=3,
-        kernel_size=3,
-        padding=1
-    )
-  
+    self.conv_out = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=3, kernel_size=3, padding=1)
+
  def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
    if feat_cache is not None:
      idx = feat_idx[0]
-      cache_x = jnp.copy(x[:, -CACHE_T: , :, :, :])
+      cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
        # cache last frame of the last two chunk
        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
@@ -739,10 +665,10 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
      feat_idx[0] += 1
    else:
      x = self.conv_in(x)
-    
+
    ## middle
    x = self.mid_block(x, feat_cache, feat_idx)
-    
+
    ## upsamples
    for up_block in self.up_blocks:
      x = up_block(x, feat_cache, feat_idx)
@@ -763,25 +689,55 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
      x = self.conv_out(x)
    return x

+
 class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin):
+
  def __init__(
-      self,
-      rngs: nnx.Rngs,
-      base_dim: int = 96,
-      z_dim: int = 16,
-      dim_mult: Tuple[int] = [1,2,4,4],
-      num_res_blocks: int = 2,
-      attn_scales: List[float] = [],
-      temperal_downsample: List[bool] = [False, True, True],
-      dropout: float = 0.0,
-      latents_mean: List[float] = [
-          -0.7571,-0.7089,-0.9113,0.1075,-0.1745,0.9653,-0.1517, 1.5508,
- 
0.4134,-0.0715,0.5517,-0.3632,-0.1922,-0.9497,0.2503,-0.2921, - ], - latents_std: List[float] = [ - 2.8184,1.4541,2.3275,2.6558,1.2196,1.7708,2.6052,2.0743, - 3.2687,2.1526,2.8652,1.5579,1.6382,1.1253,2.8251,1.9160, - ], + self, + rngs: nnx.Rngs, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + latents_std: List[float] = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ], ): self.z_dim = z_dim self.temperal_downsample = temperal_downsample @@ -790,48 +746,44 @@ def __init__( self.latents_std = latents_std self.encoder = WanEncoder3d( - rngs=rngs, - dim=base_dim, - z_dim=z_dim * 2, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_downsample=temperal_downsample, - dropout=dropout, - ) - self.quant_conv = WanCausalConv3d( - rngs=rngs, - in_channels=z_dim * 2, - out_channels=z_dim * 2, - kernel_size=1 + rngs=rngs, + dim=base_dim, + z_dim=z_dim * 2, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + dropout=dropout, ) + self.quant_conv = WanCausalConv3d(rngs=rngs, in_channels=z_dim * 2, out_channels=z_dim * 2, kernel_size=1) self.post_quant_conv = WanCausalConv3d( - rngs=rngs, - in_channels=z_dim, - out_channels=z_dim, - kernel_size=1, + rngs=rngs, + in_channels=z_dim, + out_channels=z_dim, + kernel_size=1, ) self.decoder = WanDecoder3d( - rngs=rngs, - dim=base_dim, - z_dim=z_dim, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_upsample=self.temporal_upsample, - dropout=dropout + rngs=rngs, + dim=base_dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_upsample=self.temporal_upsample, + dropout=dropout, ) self.clear_cache() - + def clear_cache(self): - """ Resets cache dictionaries and indices""" + """Resets cache dictionaries and indices""" + def _count_conv3d(module): count = 0 node_types = nnx.graph.iter_graph([module]) for path, value in node_types: if isinstance(value, WanCausalConv3d): - count +=1 + count += 1 return count self._conv_num = _count_conv3d(self.decoder) @@ -848,24 +800,18 @@ def _encode(self, x: jax.Array): # reshape channel last for JAX x = jnp.transpose(x, (0, 2, 3, 4, 1)) assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}" - - #self.clear_cache() + + # self.clear_cache() t = x.shape[1] iter_ = 1 + (t - 1) // 4 for i in range(iter_): self._enc_conv_idx = [0] if i == 0: - out = self.encoder( - x[:, :1, :, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx - ) + out = self.encoder(x[:, :1, :, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) else: out_ = self.encoder( - x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx + x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx ) out = jnp.concatenate([out, out_], axis=1) enc = self.quant_conv(out) @@ -875,14 +821,16 @@ def _encode(self, x: jax.Array): # 
return enc
    return enc

-  def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
-    """ Encode video into latent distribution."""
+  def encode(
+      self, x: jax.Array, return_dict: bool = True
+  ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
+    """Encode video into latent distribution."""
    h = self._encode(x)
    posterior = FlaxDiagonalGaussianDistribution(h)
    if not return_dict:
-      return (posterior, )
+      return (posterior,)
    return FlaxAutoencoderKLOutput(latent_dist=posterior)
-  
+
  def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]:
    self.clear_cache()
    iter_ = z.shape[1]
@@ -890,16 +838,16 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu
    for i in range(iter_):
      self._conv_idx = [0]
      if i == 0:
-        out = self.decoder(x[:,i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+        out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
      else:
-        out_ = self.decoder(x[:,i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+        out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
        out = jnp.concatenate([out, out_], axis=1)
-    
+
    out = jnp.clip(out, a_min=-1.0, a_max=1.0)
    self.clear_cache()
    if not return_dict:
-      return (out, )
+      return (out,)

    return FlaxDecoderOutput(sample=out)

@@ -912,5 +860,3 @@ def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOut
    if not return_dict:
      return (decoded,)
    return FlaxDecoderOutput(sample=decoded)
-
-  
\ No newline at end of file
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py
index cdcecf80b..5cd83bbbd 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py
@@ -22,16 +22,12 @@
 import flax.linen as nn

 from ...attention_flax import FlaxFeedForward, FlaxWanAttention
-from ...embeddings_flax import (
-    get_1d_rotary_pos_embed,
-    FlaxTimesteps,
-    FlaxTimestepEmbedding,
-    PixArtAlphaTextProjection
-)
+from ...embeddings_flax import (get_1d_rotary_pos_embed, FlaxTimesteps, FlaxTimestepEmbedding, PixArtAlphaTextProjection)
 from ....configuration_utils import ConfigMixin, flax_register_to_config
 from ...modeling_flax_utils import FlaxModelMixin

+
 class WanRotaryPosEmbed(nn.Module):
  attention_head_dim: int
  patch_size: Tuple[int, int, int]
@@ -54,9 +50,9 @@ def __call__(self, hidden_states: Array) -> Array:
-    self.freqs = jnp.concatenate(freqs, dim=1)
+    self.freqs = jnp.concatenate(freqs, axis=1)

    sizes = [
-        self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
-        self.attention_head_dim // 6,
-        self.attention_head_dim // 6
+        self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
+        self.attention_head_dim // 6,
+        self.attention_head_dim // 6,
    ]
    cumulative_sizes = jnp.cumsum(jnp.array(sizes))
    split_indices = cumulative_sizes[:-1]
@@ -76,6 +72,7 @@ def __call__(self, hidden_states: Array) -> Array:

    return freqs_final

+
 class WanImageEmbeddings(nn.Module):
  out_features: int
  dtype: jnp.dtype = jnp.float32
@@ -85,18 +82,15 @@ class WanImageEmbeddings(nn.Module):
  @nn.compact
  def __call__(self, encoder_hidden_states_image: Array) -> Array:
    hidden_states = nn.LayerNorm(
-        dtype=jnp.float32, 
-        param_dtype=jnp.float32, 
+        dtype=jnp.float32,
+        param_dtype=jnp.float32,
    )(encoder_hidden_states_image)
    hidden_states = 
FlaxFeedForward( - self.out_features, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision + self.out_features, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision )(hidden_states) hidden_states = nn.LayerNorm( - dtype=jnp.float32, - param_dtype=jnp.float32, + dtype=jnp.float32, + param_dtype=jnp.float32, )(hidden_states) return hidden_states @@ -113,43 +107,37 @@ class WanTimeTextImageEmbeddings(nn.Module): @nn.compact def __call__(self, timestep: Array, encoder_hidden_states: Array, encoder_hidden_states_image: Array) -> Array: - + timestep = FlaxTimesteps( - dim=self.time_freq_dim, - flip_sin_to_cos=True, - freq_shift=0, - )(timestep) - temb = FlaxTimestepEmbedding( - time_embed_dim=self.dim, - dtype=self.dtype, - weights_dtype=self.weights_dtype + dim=self.time_freq_dim, + flip_sin_to_cos=True, + freq_shift=0, )(timestep) + temb = FlaxTimestepEmbedding(time_embed_dim=self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype)(timestep) timestep_proj = nn.Dense( - self.time_proj_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, + self.time_proj_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, )(nn.silu(temb)) encoder_hidden_states = PixArtAlphaTextProjection( - hidden_size=self.dim, - act_fn="gelu_tanh", - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision + hidden_size=self.dim, + act_fn="gelu_tanh", + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, )(encoder_hidden_states) if encoder_hidden_states_image is not None: encoder_hidden_states_image = WanImageEmbeddings( - out_features=self.dim, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision + out_features=self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision )(encoder_hidden_states_image) - + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + class WanTransformerBlock(nn.Module): dim: int ffn_dim: int @@ -160,38 +148,36 @@ class WanTransformerBlock(nn.Module): added_kv_proj_dim: Optional[int] = None @nn.compact - def __call__( - self, - hidden_states: Array, - encoder_hidden_states: Array, - temb: Array, - rotary_emb: Array - ): + def __call__(self, hidden_states: Array, encoder_hidden_states: Array, temb: Array, rotary_emb: Array): shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( - (scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 + (scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 ) # 1. 
Self-attention
-    norm_hidden_states = (nn.LayerNorm(
-        epsilon=self.eps,
-        use_bias=False,
-        use_scale=False,
-        dtype=jnp.float32,
-        param_dtype=jnp.float32,
-    )(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
+    norm_hidden_states = (
+        nn.LayerNorm(
+            epsilon=self.eps,
+            use_bias=False,
+            use_scale=False,
+            dtype=jnp.float32,
+            param_dtype=jnp.float32,
+        )(hidden_states.astype(jnp.float32))
+        * (1 + scale_msa)
+        + shift_msa
+    ).astype(hidden_states.dtype)
    attn_output = FlaxWanAttention(
-        query_dim=self.dim,
-        heads=self.num_heads,
-        dim_head=self.dim // self.num_heads,
-        dtype=self.dtype,
-        weights_dtype=self.weights_dtype,
-        attention_kernel=self.attention_kernel,
-        mesh=self.mesh,
-        flash_block_sizes=self.flash_block_sizes,
-        
+        query_dim=self.dim,
+        heads=self.num_heads,
+        dim_head=self.dim // self.num_heads,
+        dtype=self.dtype,
+        weights_dtype=self.weights_dtype,
+        attention_kernel=self.attention_kernel,
+        mesh=self.mesh,
+        flash_block_sizes=self.flash_block_sizes,
    )

+
 class WanTransformer3dModel(nn.Module, FlaxModelMixin, ConfigMixin):
  r"""
  A Transformer model for video-like data used in the Wan model.
@@ -254,15 +240,15 @@ class WanTransformer3dModel(nn.Module, FlaxModelMixin, ConfigMixin):

  @nn.compact
  def __call__(
-      self,
-      hidden_states: Array,
-      timestep: Array,
-      encoder_hidden_states: Array,
-      encoder_hidden_states_image: Optional[Array] = None,
-      return_dict: bool = True,
-      attention_kwargs: Optional[Dict[str, Any]] = None
+      self,
+      hidden_states: Array,
+      timestep: Array,
+      encoder_hidden_states: Array,
+      encoder_hidden_states_image: Optional[Array] = None,
+      return_dict: bool = True,
+      attention_kwargs: Optional[Dict[str, Any]] = None,
  ) -> Union[Array, Dict[str, Array]]:
-    
+
    inner_dim = self.num_attention_heads * self.attention_head_dim
    batch_size, num_channels, num_frames, height, width = hidden_states.shape
@@ -273,32 +259,29 @@ def __call__(

    # 1. Patch & position embedding
    rotary_emb = WanRotaryPosEmbed(
-        attention_head_dim=self.attention_head_dim,
-        patch_size=self.patch_size,
-        max_seq_len=self.rope_max_seq_len
+        attention_head_dim=self.attention_head_dim, patch_size=self.patch_size, max_seq_len=self.rope_max_seq_len
    )(hidden_states)

    hidden_states = nn.Conv(
-        features=inner_dim,
-        kernel_size=self.patch_size,
-        stride=self.patch_size,
-        dtype=self.dtype,
-        param_dtype=self.weights_dtype,
-        precision=self.precision,
+        features=inner_dim,
+        kernel_size=self.patch_size,
+        strides=self.patch_size,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        precision=self.precision,
    )(hidden_states)

-    flattened_shape = (batch_size, num_channels, -1) # TODO is his num_channels or frames?
+    flattened_shape = (batch_size, num_channels, -1)  # TODO is this num_channels or frames?
    flattened = hidden_states.reshape(flattened_shape)
    transposed = jnp.transpose(flattened, (0, 2, 1))

    # 2. 
Condition embeddings
    # image_embedding_dim=1280 for I2V model
    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = WanTimeTextImageEmbeddings(
-        dim=inner_dim,
-        time_freq_dim=self.freq_dim,
-        time_proj_dim=inner_dim * 6,
-        text_embed_dim=self.text_dim,
-        image_embed_dim=self.image_dim,
-        dtype=self.dtype,
-        weights_dtype=self.weights_dtype,
-        precision=self.precision
+        dim=inner_dim,
+        time_freq_dim=self.freq_dim,
+        time_proj_dim=inner_dim * 6,
+        text_embed_dim=self.text_dim,
+        image_embed_dim=self.image_dim,
+        dtype=self.dtype,
+        weights_dtype=self.weights_dtype,
+        precision=self.precision,
    )(timestep, encoder_hidden_states, encoder_hidden_states_image)
-
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py
index e5c1b7145..eedba51cc 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py
@@ -23,46 +23,45 @@

 BlockSizes = common_types.BlockSizes

+
 class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
+
  def __init__(
-      self,
-      rngs: nnx.Rngs,
-      model_type='t2v',
-      patch_size=(1, 2, 2),
-      text_len=512,
-      in_dim=16,
-      dim=2038,
-      ffn_dim=8192,
-      freq_dim=256,
-      text_dim=4096,
-      out_dim=16,
-      num_heads=16,
-      num_layers=32,
-      window_size=(-1, -1),
-      qk_norm=True,
-      cross_attn_norm=True,
-      eps=1e-6,
-      flash_min_seq_length: int = 4096,
-      flash_block_sizes: BlockSizes = None,
-      mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
-      precision: jax.lax.Precision = None,
-      attention: str = "dot_product",
+      self,
+      rngs: nnx.Rngs,
+      model_type="t2v",
+      patch_size=(1, 2, 2),
+      text_len=512,
+      in_dim=16,
+      dim=2038,
+      ffn_dim=8192,
+      freq_dim=256,
+      text_dim=4096,
+      out_dim=16,
+      num_heads=16,
+      num_layers=32,
+      window_size=(-1, -1),
+      qk_norm=True,
+      cross_attn_norm=True,
+      eps=1e-6,
+      flash_min_seq_length: int = 4096,
+      flash_block_sizes: BlockSizes = None,
+      mesh: jax.sharding.Mesh = None,
+      dtype: jnp.dtype = jnp.float32,
+      weights_dtype: jnp.dtype = jnp.float32,
+      precision: jax.lax.Precision = None,
+      attention: str = "dot_product",
  ):
-    self.path_embedding = nnx.Conv(
-        in_dim,
-        dim,
-        kernel_size=patch_size,
-        strides=patch_size,
-        dtype=dtype,
-        param_dtype=weights_dtype,
-        precision=precision,
-        kernel_init=nnx.with_partitioning(
-            nnx.initializers.xavier_uniform(),
-            ("batch",)
-        ),
-        rngs=rngs
+    self.patch_embedding = nnx.Conv(
+        in_dim,
+        dim,
+        kernel_size=patch_size,
+        strides=patch_size,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("batch",)),
+        rngs=rngs,
    )

  def __call__(self, x):
diff --git a/src/maxdiffusion/pipelines/wan/pipeline_wan.py b/src/maxdiffusion/pipelines/wan/pipeline_wan.py
index 6bd9fb09b..27b5844a8 100644
--- a/src/maxdiffusion/pipelines/wan/pipeline_wan.py
+++ b/src/maxdiffusion/pipelines/wan/pipeline_wan.py
@@ -18,10 +18,11 @@
 from transformers import AutoTokenizer, UMT5EncoderModel
 import torch
 from ...models.wan.transformers.transformer_flux_wan_nnx import WanModel
-from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan 
+from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan
 from ..pipeline_flax_utils import FlaxDiffusionPipeline
 from ...video_processor import VideoProcessor
-#from ...schedulers import FlowMatchEulerDiscreteScheduler
+# from ...schedulers import FlowMatchEulerDiscreteScheduler
+

 class 
WanPipeline(FlaxDiffusionPipeline):
@@ -31,40 +32,39 @@ def __init__(
    text_encoder: UMT5EncoderModel,
    transformer: WanModel,
    vae: AutoencoderKLWan,
-    #scheduler: FlowMatchEulerDiscreteScheduler,
+    # scheduler: FlowMatchEulerDiscreteScheduler,
  ):
    super().__init__()

    self.register_modules(
-        vae=vae,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-        transformer=transformer,
-        #scheduler=scheduler
+        vae=vae,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        # scheduler=scheduler
    )
    self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
    self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
-
-  def _get_t5_prompt_embds(
-      self,
-      prompt: Union[str, List[str]] = None,
-      num_videos_per_prompt: int = 1,
-      max_sequence_length: int = 226,
+  def _get_t5_prompt_embeds(
+      self,
+      prompt: Union[str, List[str]] = None,
+      num_videos_per_prompt: int = 1,
+      max_sequence_length: int = 226,
  ):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = self.tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        add_special_tokens=True,
-        return_attention_mask=True,
-        return_tensors="pt",
+        prompt,
+        padding="max_length",
+        max_length=max_sequence_length,
+        truncation=True,
+        add_special_tokens=True,
+        return_attention_mask=True,
+        return_tensors="pt",
    )
    text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
    seq_lens = mask.gt(0).sum(dim=1).long()
@@ -81,4 +81,4 @@ def _get_t5_prompt_embds(
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

-    return prompt_embeds
\ No newline at end of file
+    return prompt_embeds
diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py
index db96f53fd..ec29a353e 100644
--- a/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py
+++ b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py
@@ -29,14 +29,17 @@
  broadcast_to_shape_from_left,
 )

+
 @flax.struct.dataclass
 class FlowMatchEulerDiscreteSchedulerState:
  common: CommonSchedulerState

+
 @dataclass
 class FlowMatchEulerDiscreteSchedulerOutput(FlaxSchedulerOutput):
  state: FlowMatchEulerDiscreteSchedulerState

+
 class FlowMatchEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
  # _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

@@ -45,27 +48,27 @@ class FlowMatchEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
  @property
  def has_state(self):
    return True
-  
+
  @register_to_config
  def __init__(
-      self,
-      num_train_timesteps: int = 1000,
-      shift: float = 1.0,
-      use_dynamic_shifting: bool = False,
-      base_shift: Optional[float] = 0.5,
-      max_shift: 
Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + time_shift_type: str = "exponential", + dtype: jnp.dtype = jnp.float32, ): self.dtype = dtype - + def create_state(self, common: Optional[CommonSchedulerState] = None) -> FlowMatchEulerDiscreteSchedulerState: if common is None: - common = CommonSchedulerState.create(self) \ No newline at end of file + common = CommonSchedulerState.create(self) diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index a2dd50bff..98e49f130 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -26,133 +26,135 @@ import pytest from absl.testing import absltest from ..models.wan.autoencoder_kl_wan import ( - WanCausalConv3d, - WanUpsample, - AutoencoderKLWan, - WanEncoder3d, - WanMidBlock, - WanResidualBlock, - WanRMS_norm, - WanResample, - ZeroPaddedConv2D, - WanAttentionBlock + WanCausalConv3d, + WanUpsample, + AutoencoderKLWan, + WanEncoder3d, + WanMidBlock, + WanResidualBlock, + WanRMS_norm, + WanResample, + ZeroPaddedConv2D, + WanAttentionBlock, ) CACHE_T = 2 + class TorchWanRMS_norm(nn.Module): - r""" - A custom RMS normalization layer. - - Args: - dim (int): The number of dimensions to normalize over. - channel_first (bool, optional): Whether the input tensor has channels as the first dimension. - Default is True. - images (bool, optional): Whether the input represents image data. Default is True. - bias (bool, optional): Whether to include a learnable bias term. Default is False. - """ - - def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: - super().__init__() - broadcastable_dims = (1, 1, 1) if not images else (1, 1) - shape = (dim, *broadcastable_dims) if channel_first else (dim,) - - self.channel_first = channel_first - self.scale = dim**0.5 - self.gamma = nn.Parameter(torch.ones(shape)) - self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 - - def forward(self, x): - return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + r""" + A custom RMS normalization layer. -class TorchWanResample(nn.Module): - r""" - A custom resampling module for 2D and 3D data. - - Args: - dim (int): The number of input/output channels. - mode (str): The resampling mode. Must be one of: - - 'none': No resampling (identity operation). - - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. - - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. - - 'downsample2d': 2D downsampling with zero-padding and convolution. - - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. 
- """ - - def __init__(self, dim: int, mode: str) -> None: - super().__init__() - self.dim = dim - self.mode = mode - - # layers - if mode == "upsample2d": - self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) - ) - elif mode == "upsample3d": - self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) - ) - self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) - - elif mode == "downsample2d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - elif mode == "downsample3d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class TorchWanResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. 
+ """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 else: - self.resample = nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=[0]): - b, c, t, h, w = x.size() - if self.mode == "upsample3d": - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = "Rep" - feat_idx[0] += 1 - else: - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": - # cache last frame of last two chunk - cache_x = torch.cat( - [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 - ) - if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": - cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) - if feat_cache[idx] == "Rep": - x = self.time_conv(x) - else: - x = self.time_conv(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - - x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) - x = x.reshape(b, c, t * 2, h, w) - t = x.shape[2] - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.resample(x) - x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) - - if self.mode == "downsample3d": - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 - else: - cache_x = x[:, :, -1:, :, :].clone() - x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - return x + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + 
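# restore (B, C, T, H, W): the 2D resample above ran on a (B*T, C, H, W) view,
+      # so the next line folds the frames back out of the batch dimension
+      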
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
+
+      if self.mode == "downsample3d":
+        if feat_cache is not None:
+          idx = feat_idx[0]
+          if feat_cache[idx] is None:
+            feat_cache[idx] = x.clone()
+            feat_idx[0] += 1
+          else:
+            cache_x = x[:, :, -1:, :, :].clone()
+            x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+      return x
+

 class WanVaeTest(unittest.TestCase):
+
  def setUp(self):
    WanVaeTest.dummy_data = {}
-  
+
  def test_clear_cache(self):
    key = jax.random.key(0)
    rngs = nnx.Rngs(key)
@@ -161,7 +163,7 @@ def test_clear_cache(self):

  def test_wanrms_norm(self):
    """Test against the Pytorch implementation"""
-    
+
    # --- Test Case 1: images == True ---
    dim = 96
    input_shape = (1, 96, 1, 480, 720)
@@ -170,7 +172,7 @@ def test_wanrms_norm(self):
    input = torch.ones(input_shape)
    torch_output = model(input)
    torch_output_np = torch_output.detach().numpy()
-    
+
    key = jax.random.key(0)
    rngs = nnx.Rngs(key)
    wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs)
@@ -184,7 +186,7 @@ def test_wanrms_norm(self):
    input = torch.ones(input_shape)
    torch_output = model(input)
    torch_output_np = torch_output.detach().numpy()
-    
+
    key = jax.random.key(0)
    rngs = nnx.Rngs(key)
    wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs, images=False)
@@ -192,7 +194,7 @@ def test_wanrms_norm(self):
    output = wanrms_norm(dummy_input)
    output_np = np.array(output)
    assert np.allclose(output_np, torch_output_np) == True
-  

  def test_zero_padded_conv(self):
    key = jax.random.key(0)
@@ -200,19 +202,14 @@ def test_zero_padded_conv(self):

    dim = 96
    kernel_size = 3
-    stride= (2, 2)
+    stride = (2, 2)
    resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, kernel_size, stride=stride))
    input_shape = (1, 96, 480, 720)
    input = torch.ones(input_shape)
    output_torch = resample(input)
    assert output_torch.shape == (1, 96, 240, 360)

-    model = ZeroPaddedConv2D(
-        dim=dim,
-        rngs=rngs,
-        kernel_size=(1, 3, 3),
-        stride=(1, 2, 2)
-    )
+    model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2))
    dummy_input = jnp.ones(input_shape)
    dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1))
    output = model(dummy_input)
@@ -220,7 +217,7 @@ def test_zero_padded_conv(self):
    assert output.shape == (1, 96, 240, 360)

  def test_wan_upsample(self):
-    batch_size=1
+    batch_size = 1
    in_depth, in_height, in_width = 10, 32, 32
    in_channels = 3

@@ -237,59 +234,49 @@ def test_wan_upsample(self):
    # --- Test Case 1: depth == 1 ---
    output = upsample(dummy_input)
    assert output.shape == (1, 1, 64, 64, 3)
-  
+
  def test_wan_resample(self):
    # TODO - needs to test all modes - upsample2d, upsample3d, downsample2d, downsample3d and identity
    key = jax.random.key(0)
    rngs = nnx.Rngs(key)
-    
+
    # --- Test Case 1: downsample2d ---
    batch = 1
    dim = 96
    t = 1
    h = 480
    w = 720
-    mode="downsample2d"
+    mode = "downsample2d"
    input_shape = (batch, dim, t, h, w)
    expected_output_shape = (1, dim, 1, 240, 360)
-    # output dim should be (1, 96, 1, 480, 720)
+    # output dim should be (1, 96, 1, 240, 360)
    dummy_input = torch.ones(input_shape)
-    torch_wan_resample = TorchWanResample(
-        dim=dim,
-        mode=mode
-    )
+    torch_wan_resample = TorchWanResample(dim=dim, mode=mode)
    torch_output = torch_wan_resample(dummy_input)
-    assert torch_output.shape == (batch, dim, t, h // 2, w //2)
+    assert torch_output.shape == (batch, dim, t, h // 2, w // 2)

-    wan_resample = WanResample(
-        dim,
-        mode=mode,
-        rngs=rngs
-    )
+    wan_resample = WanResample(dim, mode=mode, rngs=rngs)
    # channels is always last here
    input_shape = (batch, t, h, w, dim)
    dummy_input = 
jnp.ones(input_shape)
    output = wan_resample(dummy_input)
-    assert output.shape == (batch, t, h//2, h//2, dim)
-    breakpoint()
-    
+    assert output.shape == (batch, t, h // 2, w // 2, dim)
+
-    # --- Test Case 1: downsample3d ---
+    # --- Test Case 2: downsample3d ---
    dim = 192
    input_shape = (1, dim, 1, 240, 360)
-    torch_wan_resample = WanResample(
-        dim=dim,
-        mode="downsample3d"
-    )
+    torch_wan_resample = TorchWanResample(dim=dim, mode="downsample3d")

  def test_3d_conv(self):
    key = jax.random.key(0)
    rngs = nnx.Rngs(key)
-    batch_size=1
+    batch_size = 1
    in_depth, in_height, in_width = 10, 32, 32
    in_channels = 3
    out_channels = 16
-    kernel_d, kernel_h, kernel_w = 3, 3, 3 # Kernel size (Depth, Height, Width)
-    padding_d, padding_h, padding_w = 1, 1, 1 # Base padding (Depth, Height, Width)
+    kernel_d, kernel_h, kernel_w = 3, 3, 3  # Kernel size (Depth, Height, Width)
+    padding_d, padding_h, padding_w = 1, 1, 1  # Base padding (Depth, Height, Width)

    # Create dummy input data
    dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels))
@@ -300,11 +287,11 @@ def test_3d_conv(self):

    # Instantiate the module
    causal_conv_layer = WanCausalConv3d(
-        in_channels=in_channels,
-        out_channels=out_channels,
-        kernel_size=(kernel_d, kernel_h, kernel_w),
-        padding=(padding_d, padding_h, padding_w),
-        rngs=rngs # Pass rngs for initialization
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=(kernel_d, kernel_h, kernel_w),
+        padding=(padding_d, padding_h, padding_w),
+        rngs=rngs,  # Pass rngs for initialization
    )

    # --- Test Case 1: No Cache ---
@@ -316,7 +303,7 @@ def test_3d_conv(self):
    assert output_with_cache.shape == (1, 10, 32, 32, 16)

    # --- Test Case 3: With Cache larger than padding ---
-    larger_cache_depth = 4 # Larger than needed padding (2*padding_d = 2)
+    larger_cache_depth = 4  # Larger than needed padding (2*padding_d = 2)
    dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels))
    output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache)
    assert output_with_larger_cache.shape == (1, 10, 32, 32, 16)
@@ -335,9 +322,9 @@ def test_wan_residual(self):
    expected_output_shape = (batch, t, height, width, dim)

    wan_residual_block = WanResidualBlock(
-        in_dim=in_dim,
-        out_dim=out_dim,
-        rngs=rngs,
+        in_dim=in_dim,
+        out_dim=out_dim,
+        rngs=rngs,
    )
    dummy_input = jnp.ones(input_shape)
    dummy_output = wan_residual_block(dummy_input)
@@ -349,9 +336,9 @@ def test_wan_residual(self):
    expected_output_shape = (batch, t, height, width, out_dim)

    wan_residual_block = WanResidualBlock(
-        in_dim=in_dim,
-        out_dim=out_dim,
-        rngs=rngs,
+        in_dim=in_dim,
+        out_dim=out_dim,
+        rngs=rngs,
    )
    dummy_input = jnp.ones(input_shape)
    dummy_output = wan_residual_block(dummy_input)
@@ -365,11 +352,8 @@ def test_wan_attention(self):
    t = 1
    height = 60
    width = 90
-    input_shape=(batch, t, height, width, dim)
-    wan_attention = WanAttentionBlock(
-        dim=dim,
-        rngs=rngs
-    )
+    input_shape = (batch, t, height, width, dim)
+    wan_attention = WanAttentionBlock(dim=dim, rngs=rngs)
    dummy_input = jnp.ones(input_shape)
    output = wan_attention(dummy_input)
    assert output.shape == input_shape
@@ -383,9 +367,7 @@ def test_wan_midblock(self):
    height = 60
    width = 90
    input_shape = (batch, t, height, width, dim)
-    wan_midblock = WanMidBlock(
-        dim=dim, rngs=rngs
-    )
+    wan_midblock = WanMidBlock(dim=dim, rngs=rngs)
    dummy_input = jnp.ones(input_shape)
    output = wan_midblock(dummy_input)
    assert output.shape == input_shape
@@ -400,13 +382,13 @@ def test_wan_decode(self):
    attn_scales = []
    temperal_downsample = [False, 
True, True] wan_vae = AutoencoderKLWan( - rngs=rngs, - base_dim=dim, - z_dim=z_dim, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_downsample=temperal_downsample, + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, ) batch = 1 @@ -433,13 +415,13 @@ def test_wan_encode(self): attn_scales = [] temperal_downsample = [False, True, True] wan_vae = AutoencoderKLWan( - rngs=rngs, - base_dim=dim, - z_dim=z_dim, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_downsample=temperal_downsample, + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, ) batch = 1 channels = 3 @@ -451,5 +433,6 @@ def test_wan_encode(self): output = wan_vae.encode(input) assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16) + if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() diff --git a/src/maxdiffusion/video_processor.py b/src/maxdiffusion/video_processor.py index 2da782b46..c29485118 100644 --- a/src/maxdiffusion/video_processor.py +++ b/src/maxdiffusion/video_processor.py @@ -23,91 +23,91 @@ class VideoProcessor(VaeImageProcessor): - r"""Simple video processor.""" + r"""Simple video processor.""" - def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor: - r""" - Preprocesses input video(s). + def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor: + r""" + Preprocesses input video(s). - Args: - video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`): - The input video. It can be one of the following: - * List of the PIL images. - * List of list of PIL images. - * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`). - * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). - * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, - width)`). - * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). - * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width, - num_channels)`. - * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height, - width)`. - height (`int`, *optional*, defaults to `None`): - The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to - get default height. - width (`int`, *optional*`, defaults to `None`): - The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get - the default width. - """ - if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: - warnings.warn( - "Passing `video` as a list of 5d np.ndarray is deprecated." - "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", - FutureWarning, - ) - video = np.concatenate(video, axis=0) - if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5: - warnings.warn( - "Passing `video` as a list of 5d torch.Tensor is deprecated." 
- "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor", - FutureWarning, - ) - video = torch.cat(video, axis=0) + Args: + video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`): + The input video. It can be one of the following: + * List of the PIL images. + * List of list of PIL images. + * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`). + * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, + width)`). + * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width, + num_channels)`. + * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height, + width)`. + height (`int`, *optional*, defaults to `None`): + The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to + get default height. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get + the default width. + """ + if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", + FutureWarning, + ) + video = np.concatenate(video, axis=0) + if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor", + FutureWarning, + ) + video = torch.cat(video, axis=0) - # ensure the input is a list of videos: - # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) - # - if it is a single video, it is convereted to a list of one video. - if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: - video = list(video) - elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): - video = [video] - elif isinstance(video, list) and is_valid_image_imagelist(video[0]): - video = video - else: - raise ValueError( - "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" - ) + # ensure the input is a list of videos: + # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) + # - if it is a single video, it is convereted to a list of one video. + if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: + video = list(video) + elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): + video = [video] + elif isinstance(video, list) and is_valid_image_imagelist(video[0]): + video = video + else: + raise ValueError( + "Input is in incorrect format. 
Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" + ) - video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) - # move the number of channels before the number of frames. - video = video.permute(0, 2, 1, 3, 4) + # move the number of channels before the number of frames. + video = video.permute(0, 2, 1, 3, 4) - return video + return video - def postprocess_video( - self, video: torch.Tensor, output_type: str = "np" - ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: - r""" - Converts a video tensor to a list of frames for export. + def postprocess_video( + self, video: torch.Tensor, output_type: str = "np" + ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: + r""" + Converts a video tensor to a list of frames for export. - Args: - video (`torch.Tensor`): The video as a tensor. - output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. - """ - batch_size = video.shape[0] - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = self.postprocess(batch_vid, output_type) - outputs.append(batch_output) + Args: + video (`torch.Tensor`): The video as a tensor. + output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. + """ + batch_size = video.shape[0] + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = self.postprocess(batch_vid, output_type) + outputs.append(batch_output) - if output_type == "np": - outputs = np.stack(outputs) - elif output_type == "pt": - outputs = torch.stack(outputs) - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") + if output_type == "np": + outputs = np.stack(outputs) + elif output_type == "pt": + outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. 
Please choose one of ['np', 'pt', 'pil']") - return outputs + return outputs From 089f8ac5a0ed00b56f53fce4226b6c2232e89573 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 29 Apr 2025 20:32:34 +0000 Subject: [PATCH 15/25] fix unit tests --- src/maxdiffusion/tests/wan_vae_test.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 98e49f130..fc3f1cb6d 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -221,19 +221,13 @@ def test_wan_upsample(self): in_depth, in_height, in_width = 10, 32, 32 in_channels = 3 - dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels)) + dummy_input = jnp.ones((batch_size * in_depth, in_height, in_width, in_channels)) upsample = WanUpsample(scale_factor=(2.0, 2.0)) # --- Test Case 1: depth > 1 --- output = upsample(dummy_input) - assert output.shape == (1, 10, 64, 64, 3) - - in_depth = 1 - dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels)) - # --- Test Case 1: depth == 1 --- - output = upsample(dummy_input) - assert output.shape == (1, 1, 64, 64, 3) + assert output.shape == (10, 64, 64, 3) def test_wan_resample(self): # TODO - needs to test all modes - upsample2d, upsample3d, downsample2d, downsample3d and identity @@ -260,13 +254,7 @@ def test_wan_resample(self): input_shape = (batch, t, h, w, dim) dummy_input = jnp.ones(input_shape) output = wan_resample(dummy_input) - assert output.shape == (batch, t, h // 2, h // 2, dim) - breakpoint() - - # --- Test Case 1: downsample3d --- - dim = 192 - input_shape = (1, dim, 1, 240, 360) - torch_wan_resample = WanResample(dim=dim, mode="downsample3d") + assert output.shape == (batch, t, h // 2, w // 2, dim) def test_3d_conv(self): key = jax.random.key(0) From 40d423d097efb3527e0d7778e863330128d08dc4 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 2 May 2025 16:07:00 +0000 Subject: [PATCH 16/25] e2e wan vae with weights loading. Still not fully working. --- src/maxdiffusion/configuration_utils.py | 3 +- src/maxdiffusion/generate_wan.py | 1 + src/maxdiffusion/models/flux/util.py | 21 +-- .../models/modeling_flax_pytorch_utils.py | 49 +++++- .../models/wan/autoencoder_kl_wan.py | 26 +-- src/maxdiffusion/models/wan/wan_utils.py | 78 +++++++++ src/maxdiffusion/tests/wan_vae_test.py | 16 +- src/maxdiffusion/utils/__init__.py | 3 +- src/maxdiffusion/utils/export_utils.py | 114 ++++++++++-- src/maxdiffusion/utils/import_utils.py | 23 +++ src/maxdiffusion/utils/loading_utils copy.py | 162 ++++++++++++++++++ src/maxdiffusion/utils/loading_utils.py | 88 +++++++++- 12 files changed, 527 insertions(+), 57 deletions(-) create mode 100644 src/maxdiffusion/models/wan/wan_utils.py create mode 100644 src/maxdiffusion/utils/loading_utils copy.py diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py index 8bcb6d80d..bfd420f72 100644 --- a/src/maxdiffusion/configuration_utils.py +++ b/src/maxdiffusion/configuration_utils.py @@ -464,7 +464,8 @@ def extract_init_dict(cls, config_dict, **kwargs): # remove flax internal keys if hasattr(cls, "_flax_internal_args"): for arg in cls._flax_internal_args: - expected_keys.remove(arg) + if arg in expected_keys: + expected_keys.remove(arg) # 2. 
Remove attributes that cannot be expected from expected config attributes # remove keys to be ignored diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 3a79621d3..ce33c6806 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -26,6 +26,7 @@ from absl import app from transformers import AutoTokenizer, UMT5EncoderModel from maxdiffusion import pyconfig, max_logging +from maxdiffusion.models.wan.autoencoder_kl_wan import AutoencoderKLWan from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel from maxdiffusion.pipelines.wan.pipeline_wan import WanPipeline diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index 362a39171..26371e4db 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -11,7 +11,11 @@ from jax import numpy as jnp from safetensors import safe_open -from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor) +from ..modeling_flax_pytorch_utils import ( + rename_key, + rename_key_and_reshape_tensor, + torch2jax +) from maxdiffusion import max_logging @@ -32,21 +36,6 @@ class FluxParams: rngs: Array param_dtype: DTypeLike - -def torch2jax(torch_tensor: torch.Tensor) -> Array: - is_bfloat16 = torch_tensor.dtype == torch.bfloat16 - if is_bfloat16: - # upcast the tensor to fp32 - torch_tensor = torch_tensor.float() - - if torch.device.type != "cpu": - torch_tensor = torch_tensor.to("cpu") - - numpy_value = torch_tensor.numpy() - jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None) - return jax_array - - @dataclass class ModelSpec: params: FluxParams diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 9552c69f1..74c7fce50 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -15,18 +15,57 @@ """ PyTorch - Flax general utilities.""" import re +import torch import jax import jax.numpy as jnp from flax.linen import Partitioned from flax.traverse_util import flatten_dict, unflatten_dict from flax.core.frozen_dict import unfreeze from jax.random import PRNGKey - +from chex import Array from ..utils import logging +from .. import max_logging logger = logging.get_logger(__name__) +def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict): + """ + expected_pytree: dict - a pytree that comes from initializing the model. + new_pytree: dict - a pytree that has been created from pytorch weights. 
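+
+  Minimal usage sketch (mirroring the call in `load_wan_vae` below, which passes the
+  unflattened pytree of expected model shapes and the already-flattened dict of
+  converted tensors):
+
+      validate_flax_state_dict(eval_shapes, flax_state_dict)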
+ """ + expected_pytree = flatten_dict(expected_pytree) + if len(expected_pytree.keys()) != len(new_pytree.keys()): + set1 = set(expected_pytree.keys()) + set2 = set(new_pytree.keys()) + missing_keys = set1 ^ set2 + max_logging.log(f"missing keys : {missing_keys}") + for key in expected_pytree.keys(): + if key in new_pytree.keys(): + try: + expected_pytree_shape = expected_pytree[key].shape + except Exception: + expected_pytree_shape = expected_pytree[key].value.shape + if expected_pytree_shape != new_pytree[key].shape: + max_logging.log(f"shape mismatch for {key}") + max_logging.log( + f"shape mismatch, expected shape of {expected_pytree[key].shape}, but got shape of {new_pytree[key].shape}" + ) + else: + max_logging.log(f"key: {key} not found...") + +def torch2jax(torch_tensor: torch.Tensor) -> Array: + is_bfloat16 = torch_tensor.dtype == torch.bfloat16 + if is_bfloat16: + # upcast the tensor to fp32 + torch_tensor = torch_tensor.float() + + if torch.device.type != "cpu": + torch_tensor = torch_tensor.to("cpu") + + numpy_value = torch_tensor.numpy() + jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None) + return jax_array def rename_key(key): regex = r"\w+[.]\d+" @@ -93,6 +132,12 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: pt_tensor = pt_tensor.transpose(2, 3, 1, 0) return renamed_pt_tuple_key, pt_tensor + + # 3d conv layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 5: + pt_tensor = pt_tensor.transpose(2, 3, 4, 1, 0) + return renamed_pt_tuple_key, pt_tensor # linear layer renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) @@ -103,6 +148,8 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic # old PyTorch layer norm weight renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) if pt_tuple_key[-1] == "gamma": + renamed_pt_tuple_key = pt_tuple_key + pt_tensor = pt_tensor.flatten() return renamed_pt_tuple_key, pt_tensor # old PyTorch layer norm bias diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 2cf43f69c..ceb7bbe2f 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -200,8 +200,6 @@ def __init__( precision: jax.lax.Precision = None, attention: str = "dot_product", ): - kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") - stride = _canonicalize_tuple(stride, 3, "stride") self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs) def __call__(self, x): @@ -233,7 +231,7 @@ def __init__( nnx.Conv( dim, dim // 2, - kernel_size=(1, 3, 3), + kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, @@ -241,11 +239,11 @@ def __init__( ) elif mode == "upsample3d": self.resample = nnx.Sequential( - WanUpsample(scale_factor=(2.0, 2.0, 2.0), method="nearest"), + WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), nnx.Conv( dim, dim // 2, - kernel_size=(1, 3, 3), + kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, @@ -259,11 +257,9 @@ def __init__( padding=(1, 0, 0), ) elif mode == "downsample2d": - # TODO - do I need to transpose? 
- self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2)) elif mode == "downsample3d": - # TODO - do I need to transpose? - self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2)) self.time_conv = WanCausalConv3d( rngs=rngs, in_channels=dim, out_channels=dim, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) ) @@ -334,7 +330,6 @@ def __init__( self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False) self.conv1 = WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=3, padding=1) self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False) - self.dropout = nnx.Dropout(dropout, rngs=rngs) self.conv2 = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=out_dim, kernel_size=3, padding=1) self.conv_shortcut = ( WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=1) @@ -363,7 +358,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.norm2(x) x = self.nonlinearity(x) - x = self.dropout(x) if feat_cache is not None: idx = feat_idx[0] @@ -384,8 +378,8 @@ class WanAttentionBlock(nnx.Module): def __init__(self, dim: int, rngs: nnx.Rngs): self.dim = dim self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False) - self.to_qkv = nnx.Conv(in_features=dim, out_features=dim * 3, kernel_size=1, rngs=rngs) - self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=1, rngs=rngs) + self.to_qkv = nnx.Conv(in_features=dim, out_features=dim * 3, kernel_size=(1, 1), rngs=rngs) + self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=(1, 1), rngs=rngs) def __call__(self, x: jax.Array): batch_size, time, height, width, channels = x.shape @@ -801,8 +795,6 @@ def _encode(self, x: jax.Array): x = jnp.transpose(x, (0, 2, 3, 4, 1)) assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}" - # self.clear_cache() - t = x.shape[1] iter_ = 1 + (t - 1) // 4 for i in range(iter_): @@ -854,8 +846,8 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]: if z.shape[-1] != self.z_dim: # reshape channel last for JAX - x = jnp.transpose(x, (0, 2, 3, 4, 1)) - assert x.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {x.shape}" + z = jnp.transpose(z, (0, 2, 3, 4, 1)) + assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {z.shape}" decoded = self._decode(z).sample if not return_dict: return (decoded,) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py new file mode 100644 index 000000000..1ff994c6d --- /dev/null +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -0,0 +1,78 @@ +import jax +import jax.numpy as jnp +from maxdiffusion import max_logging +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from flax.traverse_util import flatten_dict, unflatten_dict +from ..modeling_flax_pytorch_utils import ( + rename_key, + rename_key_and_reshape_tensor, + torch2jax, + validate_flax_state_dict +) + +def _tuple_str_to_int(in_tuple): + out_list = [] + for item in in_tuple: + try: + out_list.append(int(item)) + except: + 
out_list.append(item) + return tuple(out_list) + + +def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): + device = jax.devices(device)[0] + with jax.default_device(device): + if hf_download: + ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors") + #breakpoint() + max_logging.log(f"Load and port Wan 2.1 VAE on {device}") + + if ckpt_path is not None: + tensors = {} + with safe_open(ckpt_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = torch2jax(f.get_tensor(k)) + flax_state_dict = {} + cpu = jax.local_devices(backend="cpu")[0] + for pt_key, tensor in tensors.items(): + renamed_pt_key = rename_key(pt_key) + # Order matters + renamed_pt_key = renamed_pt_key.replace("up_blocks_", "up_blocks.") + renamed_pt_key = renamed_pt_key.replace("mid_block_", "mid_block.") + renamed_pt_key = renamed_pt_key.replace("down_blocks_", "down_blocks.") + + renamed_pt_key = renamed_pt_key.replace("conv_in.bias", "conv_in.conv.bias") + renamed_pt_key = renamed_pt_key.replace("conv_in.weight", "conv_in.conv.weight") + renamed_pt_key = renamed_pt_key.replace("conv_out.bias", "conv_out.conv.bias") + renamed_pt_key = renamed_pt_key.replace("conv_out.weight", "conv_out.conv.weight") + renamed_pt_key = renamed_pt_key.replace("attentions_", "attentions.") + renamed_pt_key = renamed_pt_key.replace("resnets_", "resnets.") + renamed_pt_key = renamed_pt_key.replace("upsamplers_", "upsamplers.") + renamed_pt_key = renamed_pt_key.replace("resample_", "resample.") + renamed_pt_key = renamed_pt_key.replace("conv1.bias", "conv1.conv.bias") + renamed_pt_key = renamed_pt_key.replace("conv1.weight", "conv1.conv.weight") + renamed_pt_key = renamed_pt_key.replace("conv2.bias", "conv2.conv.bias") + renamed_pt_key = renamed_pt_key.replace("conv2.weight", "conv2.conv.weight") + renamed_pt_key = renamed_pt_key.replace("time_conv.bias", "time_conv.conv.bias") + renamed_pt_key = renamed_pt_key.replace("time_conv.weight", "time_conv.conv.weight") + renamed_pt_key = renamed_pt_key.replace("quant_conv", "quant_conv.conv") + renamed_pt_key = renamed_pt_key.replace("conv_shortcut", "conv_shortcut.conv") + if "decoder" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("resample.1.bias", "resample.layers.1.bias") + renamed_pt_key = renamed_pt_key.replace("resample.1.weight", "resample.layers.1.weight") + if "encoder" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("resample.1", "resample.conv") + pt_tuple_key = tuple(renamed_pt_key.split(".")) + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes) + flax_key = _tuple_str_to_int(flax_key) + flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) + del tensors + jax.clear_caches() + else: + raise FileNotFoundError(f"Path {ckpt_path} was not found") + + return flax_state_dict \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index fc3f1cb6d..0ae3f3a26 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -29,7 +29,6 @@ WanCausalConv3d, WanUpsample, AutoencoderKLWan, - WanEncoder3d, WanMidBlock, WanResidualBlock, WanRMS_norm, @@ -37,6 +36,7 @@ ZeroPaddedConv2D, WanAttentionBlock, ) +from ..models.wan.wan_utils import load_wan_vae CACHE_T = 2 @@ -421,6 
+421,20 @@ def test_wan_encode(self): output = wan_vae.encode(input) assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16) + # def test_load_checkpoint(self): + # pretrained_model_name_or_path = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + # key = jax.random.key(0) + # rngs = nnx.Rngs(key) + # wan_vae = AutoencoderKLWan.from_config( + # pretrained_model_name_or_path, + # subfolder="vae", + # rngs=rngs + # ) + # graphdef, state = nnx.split(wan_vae) + # params = state.to_pure_dict() + # # This replaces random params with the model. + # params = load_wan_vae(pretrained_model_name_or_path, params, "cpu") + if __name__ == "__main__": absltest.main() diff --git a/src/maxdiffusion/utils/__init__.py b/src/maxdiffusion/utils/__init__.py index 199b7c156..256934b3b 100644 --- a/src/maxdiffusion/utils/__init__.py +++ b/src/maxdiffusion/utils/__init__.py @@ -83,7 +83,7 @@ is_xformers_available, requires_backends, ) -from .loading_utils import load_image +from .loading_utils import load_image, load_video from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( @@ -103,7 +103,6 @@ convert_unet_state_dict_to_peft, ) - logger = get_logger(__name__) diff --git a/src/maxdiffusion/utils/export_utils.py b/src/maxdiffusion/utils/export_utils.py index cd31a7baa..461922e94 100644 --- a/src/maxdiffusion/utils/export_utils.py +++ b/src/maxdiffusion/utils/export_utils.py @@ -3,14 +3,14 @@ import struct import tempfile from contextlib import contextmanager -from typing import List +from typing import List, Optional, Union import numpy as np import PIL.Image import PIL.ImageOps -from .import_utils import BACKENDS_MAPPING, is_opencv_available +from .import_utils import BACKENDS_MAPPING, is_imageio_available, is_opencv_available from .logging import get_logger @@ -110,19 +110,97 @@ def export_to_obj(mesh, output_obj_path: str = None): with open(output_obj_path, "w") as f: f.writelines("\n".join(combined_data)) - -def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str: - if is_opencv_available(): - import cv2 - else: - raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) - if output_video_path is None: - output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name - - fourcc = cv2.VideoWriter_fourcc(*"mp4v") - h, w, c = video_frames[0].shape - video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h)) - for i in range(len(video_frames)): - img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) - video_writer.write(img) - return output_video_path +def _legacy_export_to_video( + video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10 +): + if is_opencv_available(): + import cv2 + else: + raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + if isinstance(video_frames[0], np.ndarray): + video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] + + elif isinstance(video_frames[0], PIL.Image.Image): + video_frames = [np.array(frame) for frame in video_frames] + + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + h, w, c = video_frames[0].shape + video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h)) + for i in range(len(video_frames)): + img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + + return output_video_path + +def export_to_video( + video_frames: 
Union[List[np.ndarray], List[PIL.Image.Image]],
+    output_video_path: str = None,
+    fps: int = 10,
+    quality: float = 5.0,
+    bitrate: Optional[int] = None,
+    macro_block_size: Optional[int] = 16,
+) -> str:
+  """
+  quality:
+    Video output quality. Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to
+    prevent variable bitrate flags being passed to FFMPEG so you can manually specify them using output_params
+    instead. Specifying a fixed bitrate using `bitrate` disables this parameter.
+
+  bitrate:
+    Set a constant bitrate for the video encoding. Default is None, causing the `quality` parameter to be used
+    instead. Better quality videos with smaller file sizes will result from using the `quality` variable bitrate
+    parameter rather than specifying a fixed bitrate with this parameter.
+
+  macro_block_size:
+    Size constraint for the video. Width and height must be divisible by this number. If they are not, imageio will
+    tell ffmpeg to scale the image up to the next closest size divisible by this number. Most codecs are compatible
+    with a macroblock size of 16 (default); some can go smaller (4, 8). To disable this automatic feature set it to
+    None or 1, but be warned that many players can't decode videos that are odd in size and some codecs will produce
+    poor results or fail. See https://en.wikipedia.org/wiki/Macroblock.
+  """
+  # TODO: Dhruv. Remove by Diffusers release 0.33.0
+  # Added to prevent breaking existing code
+  if not is_imageio_available():
+    logger.warning(
+        (
+            "It is recommended to use `export_to_video` with `imageio` and `imageio-ffmpeg` as a backend. \n"
+            "These libraries are not present in your environment. Attempting to use legacy OpenCV backend to export video. \n"
+            "Support for the OpenCV backend will be deprecated in a future Diffusers version"
+        )
+    )
+    return _legacy_export_to_video(video_frames, output_video_path, fps)
+
+  if is_imageio_available():
+    import imageio
+  else:
+    raise ImportError(BACKENDS_MAPPING["imageio"][1].format("export_to_video"))
+
+  try:
+    imageio.plugins.ffmpeg.get_exe()
+  except AttributeError:
+    raise AttributeError(
+        (
+            "Found an existing imageio backend in your environment. Attempting to export video with imageio. \n"
+            "Unable to find a compatible ffmpeg installation in your environment to use with imageio. 
Please install via `pip install imageio-ffmpeg" + ) + ) + + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + if isinstance(video_frames[0], np.ndarray): + video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] + + elif isinstance(video_frames[0], PIL.Image.Image): + video_frames = [np.array(frame) for frame in video_frames] + + with imageio.get_writer( + output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size + ) as writer: + for frame in video_frames: + writer.append_data(frame) + + return output_video_path diff --git a/src/maxdiffusion/utils/import_utils.py b/src/maxdiffusion/utils/import_utils.py index 99d88336a..224bad16c 100644 --- a/src/maxdiffusion/utils/import_utils.py +++ b/src/maxdiffusion/utils/import_utils.py @@ -51,6 +51,19 @@ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} +def _is_package_available(pkg_name: str): + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" + + if pkg_exists: + try: + pkg_version = importlib_metadata.version(pkg_name) + logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False + + return pkg_exists, pkg_version + _torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: _torch_available = importlib.util.find_spec("torch") is not None @@ -105,6 +118,7 @@ except importlib_metadata.PackageNotFoundError: _transformers_available = False +_imageio_available, _imageio_version = _is_package_available("imageio") _inflect_available = importlib.util.find_spec("inflect") is not None try: @@ -284,6 +298,8 @@ except importlib_metadata.PackageNotFoundError: _peft_available = False +def is_imageio_available(): + return _imageio_available def is_torch_available(): return _torch_available @@ -486,6 +502,11 @@ def is_peft_available(): {0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0` """ +# docstyle-ignore +IMAGEIO_IMPORT_ERROR = """ +{0} requires the imageio library and ffmpeg but it was not found in your environment. 
You can install it with pip: `pip install imageio imageio-ffmpeg` +""" + BACKENDS_MAPPING = OrderedDict( [ @@ -506,6 +527,7 @@ def is_peft_available(): ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), + ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), ] ) @@ -705,3 +727,4 @@ def _get_module(self, module_name: str): def __reduce__(self): return (self.__class__, (self._name, self.__file__, self._import_structure)) + diff --git a/src/maxdiffusion/utils/loading_utils copy.py b/src/maxdiffusion/utils/loading_utils copy.py new file mode 100644 index 000000000..fd66aaa4d --- /dev/null +++ b/src/maxdiffusion/utils/loading_utils copy.py @@ -0,0 +1,162 @@ +import os +import tempfile +from typing import Any, Callable, List, Optional, Tuple, Union +from urllib.parse import unquote, urlparse + +import PIL.Image +import PIL.ImageOps +import requests + +from .import_utils import BACKENDS_MAPPING, is_imageio_available + + +def load_image( + image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None +) -> PIL.Image.Image: + """ + Loads `image` to a PIL Image. + + Args: + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*): + A conversion method to apply to the image after loading it. When set to `None` the image will be converted + "RGB". + + Returns: + `PIL.Image.Image`: + A PIL Image. + """ + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + image = PIL.Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path." + ) + elif isinstance(image, PIL.Image.Image): + image = image + else: + raise ValueError( + "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image." + ) + + image = PIL.ImageOps.exif_transpose(image) + + if convert_method is not None: + image = convert_method(image) + else: + image = image.convert("RGB") + + return image + + +def load_video( + video: str, + convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, +) -> List[PIL.Image.Image]: + """ + Loads `video` to a list of PIL Image. + + Args: + video (`str`): + A URL or Path to a video to convert to a list of PIL Image format. + convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): + A conversion method to apply to the video after loading it. When set to `None` the images will be converted + to "RGB". + + Returns: + `List[PIL.Image.Image]`: + The video as a list of PIL images. + """ + is_url = video.startswith("http://") or video.startswith("https://") + is_file = os.path.isfile(video) + was_tempfile_created = False + + if not (is_url or is_file): + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." + ) + + if is_url: + response = requests.get(video, stream=True) + if response.status_code != 200: + raise ValueError(f"Failed to download video. 
Status code: {response.status_code}") + + parsed_url = urlparse(video) + file_name = os.path.basename(unquote(parsed_url.path)) + + suffix = os.path.splitext(file_name)[1] or ".mp4" + video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name + + was_tempfile_created = True + + video_data = response.iter_content(chunk_size=8192) + with open(video_path, "wb") as f: + for chunk in video_data: + f.write(chunk) + + video = video_path + + pil_images = [] + if video.endswith(".gif"): + gif = PIL.Image.open(video) + try: + while True: + pil_images.append(gif.copy()) + gif.seek(gif.tell() + 1) + except EOFError: + pass + + else: + if is_imageio_available(): + import imageio + else: + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg" + ) + + with imageio.get_reader(video) as reader: + # Read all frames + for frame in reader: + pil_images.append(PIL.Image.fromarray(frame)) + + if was_tempfile_created: + os.remove(video_path) + + if convert_method is not None: + pil_images = convert_method(pil_images) + + return pil_images + + +# Taken from `transformers`. +def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]: + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + return module, tensor_name + + +def get_submodule_by_name(root_module, module_path: str): + current = root_module + parts = module_path.split(".") + for part in parts: + if part.isdigit(): + idx = int(part) + current = current[idx] # e.g., for nn.ModuleList or nn.Sequential + else: + current = getattr(current, part) + return current diff --git a/src/maxdiffusion/utils/loading_utils.py b/src/maxdiffusion/utils/loading_utils.py index 07c08e726..77caa39e0 100644 --- a/src/maxdiffusion/utils/loading_utils.py +++ b/src/maxdiffusion/utils/loading_utils.py @@ -1,9 +1,12 @@ import os -from typing import Union +from typing import Any, Callable, List, Optional, Tuple, Union import PIL.Image import PIL.ImageOps import requests +import tempfile +from urllib.parse import unquote, urlparse +from .import_utils import BACKENDS_MAPPING, is_imageio_available def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: @@ -33,3 +36,86 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: image = PIL.ImageOps.exif_transpose(image) image = image.convert("RGB") return image + +def load_video( + video: str, + convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, +) -> List[PIL.Image.Image]: + """ + Loads `video` to a list of PIL Image. + + Args: + video (`str`): + A URL or Path to a video to convert to a list of PIL Image format. + convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): + A conversion method to apply to the video after loading it. When set to `None` the images will be converted + to "RGB". + + Returns: + `List[PIL.Image.Image]`: + The video as a list of PIL images. + """ + is_url = video.startswith("http://") or video.startswith("https://") + is_file = os.path.isfile(video) + was_tempfile_created = False + + if not (is_url or is_file): + raise ValueError( + f"Incorrect path or URL. 
URLs must start with `http://` or `https://`, and {video} is not a valid path." + ) + + if is_url: + response = requests.get(video, stream=True) + if response.status_code != 200: + raise ValueError(f"Failed to download video. Status code: {response.status_code}") + + parsed_url = urlparse(video) + file_name = os.path.basename(unquote(parsed_url.path)) + + suffix = os.path.splitext(file_name)[1] or ".mp4" + video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name + + was_tempfile_created = True + + video_data = response.iter_content(chunk_size=8192) + with open(video_path, "wb") as f: + for chunk in video_data: + f.write(chunk) + + video = video_path + + pil_images = [] + if video.endswith(".gif"): + gif = PIL.Image.open(video) + try: + while True: + pil_images.append(gif.copy()) + gif.seek(gif.tell() + 1) + except EOFError: + pass + + else: + if is_imageio_available(): + import imageio + else: + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg" + ) + + with imageio.get_reader(video) as reader: + # Read all frames + for frame in reader: + pil_images.append(PIL.Image.fromarray(frame)) + + if was_tempfile_created: + os.remove(video_path) + + if convert_method is not None: + pil_images = convert_method(pil_images) + + return pil_images \ No newline at end of file From 34ebdbea48b2f6f874d1106096a5c05f174d691d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 7 May 2025 18:13:23 +0000 Subject: [PATCH 17/25] debug statements --- .../models/wan/autoencoder_kl_wan.py | 118 +++++++++++++++--- src/maxdiffusion/models/wan/wan_utils.py | 1 - src/maxdiffusion/tests/wan_vae_test.py | 15 +-- 3 files changed, 107 insertions(+), 27 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index ceb7bbe2f..06bd37aaa 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -23,7 +23,7 @@ from ..modeling_flax_utils import FlaxModelMixin from ... 
import common_types from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) - +import numpy as np BlockSizes = common_types.BlockSizes CACHE_T = 2 @@ -93,11 +93,16 @@ def __init__( rngs=rngs, ) - def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Array: + def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array: + print("wanCausalConv3d, x min: ", np.min(x)) + print("wanCausalConv3d, x max: ", np.max(x)) current_padding = list(self._causal_padding) # Mutable copy padding_needed = self._depth_padding_before if cache_x is not None and padding_needed > 0: + print("WanCausalConv3d, cache.shape: ", cache_x.shape) + print("wanCausalConv3d, cache_x min: ", np.min(cache_x)) + print("wanCausalConv3d, cache_x max: ", np.max(cache_x)) # Ensure cache has same spatial/channel dims, potentially different depth assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" cache_len = cache_x.shape[1] @@ -105,21 +110,34 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Arr padding_needed -= cache_len if padding_needed < 0: + print("wanCausanConv3d, padding_needed < 0") # Cache longer than needed padding, trim from start x = x[:, -padding_needed:, ...] current_padding[1] = (0, 0) # No explicit padding needed now else: # Update depth padding needed + print("wanCausanConv3d, padding_needed > 0") current_padding[1] = (padding_needed, 0) # Apply padding if any dimension requires it padding_to_apply = tuple(current_padding) + print("WanCausalConv3d, before padding x shape: ", x.shape) if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads): + print("WanCausalConv3d, applying padding") x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0) else: + print("WanCausalConv3d, NOT applying padding") x_padded = x + print("WanCausalConv3d, x shape: ", x_padded.shape) + print("wanCausalConv3d, x min: ", np.min(x_padded)) + print("wanCausalConv3d, x max: ", np.max(x_padded)) + # if idx == 12: + # breakpoint() out = self.conv(x_padded) + print("WanCausalConv3d, after conv, x shape: ", out.shape) + print("wanCausalConv3d, x min: ", np.min(out)) + print("wanCausalConv3d, x max: ", np.max(out)) return out @@ -346,11 +364,13 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = feat_idx[0] + print("Before conv1, idx: ", idx) cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) - - x = self.conv1(x, feat_cache[idx]) + x = self.conv1(x, feat_cache[idx], idx) + # if idx == 4: + # breakpoint() feat_cache[idx] = cache_x feat_idx[0] += 1 else: @@ -358,19 +378,34 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.norm2(x) x = self.nonlinearity(x) + idx = feat_idx[0] + # if idx == 4: + # breakpoint() if feat_cache is not None: idx = feat_idx[0] + print("Residual block, idx: ", idx) + # if idx == 14: + # breakpoint() + print("cache_x min: ", np.min(cache_x)) + print("cache_x max: ", np.max(cache_x)) cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + print("cache_x min: ", np.min(cache_x)) + print("cache_x max: ", np.max(cache_x)) + #breakpoint() 
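      # A sketch of the rolling cache maintained here (illustrative; it restates the
      # logic above and the matching conv1 block): each causal conv remembers the last
      # CACHE_T frames of its input so the next chunk of frames is convolved as if the
      # whole video were processed in one pass. When a chunk is shorter than CACHE_T,
      # the last cached frame is borrowed to pad it:
      #
      #   cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
      #   if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
      #     cache_x = jnp.concatenate(
      #         [jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)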
x = self.conv2(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = self.conv2(x) - - return x + h + print("before conv shortcut add: x min", np.min(x)) + print("before conv shortcut add: x max", np.max(x)) + x = x + h + print("after conv shortcut add: x min: ", np.min(x)) + print("after conv shortcut add: x max: ", np.max(x)) + return x class WanAttentionBlock(nnx.Module): @@ -382,26 +417,51 @@ def __init__(self, dim: int, rngs: nnx.Rngs): self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=(1, 1), rngs=rngs) def __call__(self, x: jax.Array): - batch_size, time, height, width, channels = x.shape + identity = x + batch_size, time, height, width, channels = x.shape x = x.reshape(batch_size * time, height, width, channels) x = self.norm(x) qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3) - - qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + #breakpoint() + #qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3) qkv = jnp.transpose(qkv, (0, 1, 3, 2)) - q, k, v = jnp.split(qkv, 3, axis=-1) - - x = jax.nn.dot_product_attention(q, k, v) + print("qkv min: ", np.min(qkv)) + print("qkv max: ", np.max(qkv)) + #q, k, v = jnp.split(qkv, 3, axis=-1) + q, k, v = jnp.split(qkv, 3, axis=-2) + print("q min: ", np.min(q)) + print("q max: ", np.max(q)) + print("k min: ", np.min(k)) + print("k min: ", np.max(k)) + print("v min: ", np.min(v)) + print("v min: ", np.max(v)) + #breakpoint() + q = jnp.transpose(q, (0, 1, 3, 2)) + k = jnp.transpose(k, (0, 1, 3, 2)) + v = jnp.transpose(v, (0, 1, 3, 2)) + import torch + import torch.nn.functional as F + q = torch.tensor(np.array(q, dtype=np.float32)) + k = torch.tensor(np.array(k, dtype=np.float32)) + v = torch.tensor(np.array(v, dtype=np.float32)) + #x = jax.nn.dot_product_attention(q, k, v) + x = F.scaled_dot_product_attention(q, k, v) + print("attn min: ", torch.min(x)) + print("attn max: ", torch.max(x)) + #breakpoint() + x = jnp.array(x.detach().numpy()) x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels) # output projection x = self.proj(x) - + #breakpoint() # Reshape back x = x.reshape(batch_size, time, height, width, channels) + #breakpoint() return x + identity @@ -419,11 +479,20 @@ def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity self.resnets = resnets def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + print("WanMidblock...") x = self.resnets[0](x, feat_cache, feat_idx) + print("WanMidBlock resnets[0], x min: ", np.min(x)) + print("WanMidBlock resnets[0], x max: ", np.max(x)) for attn, resnet in zip(self.attentions, self.resnets[1:]): + print("WanMidBlock, for loop, attn len: ", len(self.attentions)) + print("WanMidBlock, for loop, resnets len: ", len(self.resnets)) if attn is not None: x = attn(x) + print("WanMidBlock attn[0], x min: ", np.min(x)) + print("WanMidBlock attn[0], x max: ", np.max(x)) x = resnet(x, feat_cache, feat_idx) + print("WanMidBlock resnets[i], x min: ", np.min(x)) + print("WanMidBlock resnets[i], x max: ", np.max(x)) return x @@ -589,7 +658,7 @@ def __init__( self, rngs: nnx.Rngs, dim: int = 128, - z_dim: int = 128, + z_dim: int = 4, dim_mult: List[int] = [1, 2, 4, 4], num_res_blocks: int = 2, attn_scales=List[float], @@ -662,7 +731,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): ## middle x = self.mid_block(x, feat_cache, feat_idx) - + #breakpoint() ## upsamples for up_block in self.up_blocks: x = up_block(x, feat_cache, feat_idx) @@ 
-810,7 +879,6 @@ def _encode(self, x: jax.Array): mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] enc = jnp.concatenate([mu, logvar], axis=-1) self.clear_cache() - # return enc return enc def encode( @@ -833,10 +901,22 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) else: out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) - out = jnp.concatenate([out, out_], axis=1) - - out = jnp.clip(out, a_min=-1.0, a_max=1.0) + print("out_.shape: ", out_.shape) + print("out_ min: ", np.min(out_)) + print("out_ max: ", np.max(out_)) + print("out.shape: ", out.shape) + print("out min: ", np.min(out)) + print("out max: ", np.max(out)) + for i in range(len(self._feat_map)): + if isinstance(self._feat_map[i], jax.Array): + print("i: ", i) + print("min: ", np.min(self._feat_map[i])) + print("max: ", np.max(self._feat_map[i])) + else: + print(f"feat_map[{i}] : {self._feat_map[i]}") + # breakpoint() + out = jnp.clip(out, min=-1.0, max=1.0) self.clear_cache() if not return_dict: return (out,) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 1ff994c6d..2a8517954 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -26,7 +26,6 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: with jax.default_device(device): if hf_download: ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors") - #breakpoint() max_logging.log(f"Load and port Wan 2.1 VAE on {device}") if ckpt_path is not None: diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 0ae3f3a26..e71ddbde2 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -40,7 +40,6 @@ CACHE_T = 2 - class TorchWanRMS_norm(nn.Module): r""" A custom RMS normalization layer. 
@@ -92,16 +91,18 @@ def __init__(self, dim: int, mode: str) -> None:
         WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
       )
     elif mode == "upsample3d":
-      self.resample = nn.Sequential(
-          WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
-      )
-      self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+      # self.resample = nn.Sequential(
+      #     WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+      # )
+      # self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+      raise Exception("upsample3d not supported")
     elif mode == "downsample2d":
       self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
     elif mode == "downsample3d":
-      self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
-      self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+      raise Exception("downsample3d not supported")
+      #self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+      #self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
     else:
       self.resample = nn.Identity()

From 04f4909982697df6cc74c1a80a4dbea28ca58708 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Wed, 7 May 2025 21:39:55 +0000
Subject: [PATCH 18/25] solves distorted decoded video. Now video is jittery,
 but frames are ok.

---
 .../models/wan/autoencoder_kl_wan.py | 31 ++++++-------------
 1 file changed, 9 insertions(+), 22 deletions(-)

diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
index 06bd37aaa..296e58518 100644
--- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
+++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -286,7 +286,7 @@ def __init__(

   def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
     # Input x: (N, D, H, W, C), assume C = self.dim
-    n, d, h, w, c = x.shape
+    b, t, h, w, c = x.shape
     assert c == self.dim

@@ -308,14 +308,14 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
           x = self.time_conv(x, feat_cache[idx])
           feat_cache[idx] = cache_x
           feat_idx[0] += 1
-          x = x.reshape(n, 2, d, h, w, c)
-          x = jnp.stack([x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]], axis=2)
-          x = x.reshape(n, d * 2, h, w, c)
-    d = x.shape[1]
-    x = x.reshape(n * d, h, w, c)
+          x = x.reshape(b, t, h, w, 2, c)
+          x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
+          x = x.reshape(b, t * 2, h, w, c)
+    t = x.shape[1]
+    x = x.reshape(b * t, h, w, c)
     x = self.resample(x)
     h_new, w_new, c_new = x.shape[1:]
-    x = x.reshape(n, d, h_new, w_new, c_new)
+    x = x.reshape(b, t, h_new, w_new, c_new)

     if self.mode == "downsample3d":
       if feat_cache is not None:
@@ -425,7 +425,6 @@ def __call__(self, x: jax.Array):
     x = self.norm(x)
     qkv = self.to_qkv(x)  # Output: (N*D, H, W, C * 3)
-    #breakpoint()
     #qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
     qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
     qkv = jnp.transpose(qkv, (0, 1, 3, 2))
@@ -439,21 +438,10 @@ def __call__(self, x: jax.Array):
     print("k min: ", np.max(k))
     print("v min: ", np.min(v))
     print("v min: ", np.max(v))
-    #breakpoint()
     q = jnp.transpose(q, (0, 1, 3, 2))
     k = jnp.transpose(k, (0, 1, 3, 2))
     v = jnp.transpose(v, (0, 1, 3, 2))
-    import torch
-    import torch.nn.functional as F
-    q = 
torch.tensor(np.array(q, dtype=np.float32)) - k = torch.tensor(np.array(k, dtype=np.float32)) - v = torch.tensor(np.array(v, dtype=np.float32)) - #x = jax.nn.dot_product_attention(q, k, v) - x = F.scaled_dot_product_attention(q, k, v) - print("attn min: ", torch.min(x)) - print("attn max: ", torch.max(x)) - #breakpoint() - x = jnp.array(x.detach().numpy()) + x = jax.nn.dot_product_attention(q, k, v) x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels) # output projection @@ -696,7 +684,7 @@ def __init__( upsample_mode = None if i != len(dim_mult) - 1: upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" - # Crete and add the upsampling block + # Create and add the upsampling block up_block = WanUpBlock( in_dim=in_dim, out_dim=out_dim, @@ -731,7 +719,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): ## middle x = self.mid_block(x, feat_cache, feat_idx) - #breakpoint() ## upsamples for up_block in self.up_blocks: x = up_block(x, feat_cache, feat_idx) From 66146b9dcf63630b16cc10ba1a341c4fa2ee455f Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 8 May 2025 01:00:53 +0000 Subject: [PATCH 19/25] fixes jittery decoder frames in vae. --- .../models/wan/autoencoder_kl_wan.py | 86 +++---------------- 1 file changed, 14 insertions(+), 72 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 296e58518..f9603e51e 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -94,15 +94,10 @@ def __init__( ) def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array: - print("wanCausalConv3d, x min: ", np.min(x)) - print("wanCausalConv3d, x max: ", np.max(x)) current_padding = list(self._causal_padding) # Mutable copy padding_needed = self._depth_padding_before if cache_x is not None and padding_needed > 0: - print("WanCausalConv3d, cache.shape: ", cache_x.shape) - print("wanCausalConv3d, cache_x min: ", np.min(cache_x)) - print("wanCausalConv3d, cache_x max: ", np.max(cache_x)) # Ensure cache has same spatial/channel dims, potentially different depth assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" cache_len = cache_x.shape[1] @@ -110,34 +105,20 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> padding_needed -= cache_len if padding_needed < 0: - print("wanCausanConv3d, padding_needed < 0") # Cache longer than needed padding, trim from start x = x[:, -padding_needed:, ...] 
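      # Worked example (illustrative numbers, consistent with the unit test's
      # larger-than-needed-padding case): a kernel depth of 3 with padding 1 gives a
      # causal depth pad of 2 * padding_d = 2. A cache of CACHE_T = 2 frames covers
      # that pad exactly (padding_needed == 0); a cache of 4 frames overshoots
      # (padding_needed == -2), so the line above trims the first two frames and the
      # line below drops the explicit depth padding.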
current_padding[1] = (0, 0) # No explicit padding needed now else: # Update depth padding needed - print("wanCausanConv3d, padding_needed > 0") current_padding[1] = (padding_needed, 0) # Apply padding if any dimension requires it padding_to_apply = tuple(current_padding) - print("WanCausalConv3d, before padding x shape: ", x.shape) if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads): - print("WanCausalConv3d, applying padding") x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0) else: - print("WanCausalConv3d, NOT applying padding") x_padded = x - - print("WanCausalConv3d, x shape: ", x_padded.shape) - print("wanCausalConv3d, x min: ", np.min(x_padded)) - print("wanCausalConv3d, x max: ", np.max(x_padded)) - # if idx == 12: - # breakpoint() out = self.conv(x_padded) - print("WanCausalConv3d, after conv, x shape: ", out.shape) - print("wanCausalConv3d, x min: ", np.min(out)) - print("wanCausalConv3d, x max: ", np.max(out)) return out @@ -300,8 +281,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": # cache last frame of last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) - if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": - cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], dim=1) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], axis=1) if feat_cache[idx] == "Rep": x = self.time_conv(x) else: @@ -364,13 +345,10 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = feat_idx[0] - print("Before conv1, idx: ", idx) cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv1(x, feat_cache[idx], idx) - # if idx == 4: - # breakpoint() feat_cache[idx] = cache_x feat_idx[0] += 1 else: @@ -379,32 +357,18 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.norm2(x) x = self.nonlinearity(x) idx = feat_idx[0] - # if idx == 4: - # breakpoint() if feat_cache is not None: idx = feat_idx[0] - print("Residual block, idx: ", idx) - # if idx == 14: - # breakpoint() - print("cache_x min: ", np.min(cache_x)) - print("cache_x max: ", np.max(cache_x)) cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) - print("cache_x min: ", np.min(cache_x)) - print("cache_x max: ", np.max(cache_x)) - #breakpoint() x = self.conv2(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = self.conv2(x) - print("before conv shortcut add: x min", np.min(x)) - print("before conv shortcut add: x max", np.max(x)) x = x + h - print("after conv shortcut add: x min: ", np.min(x)) - print("after conv shortcut add: x max: ", np.max(x)) return x @@ -428,16 +392,8 @@ def __call__(self, x: jax.Array): #qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3) qkv = jnp.transpose(qkv, (0, 1, 3, 2)) - print("qkv min: ", np.min(qkv)) - print("qkv max: ", np.max(qkv)) #q, k, v = jnp.split(qkv, 3, axis=-1) q, k, v = jnp.split(qkv, 
3, axis=-2)
-    print("q min: ", np.min(q))
-    print("q max: ", np.max(q))
-    print("k min: ", np.min(k))
-    print("k min: ", np.max(k))
-    print("v min: ", np.min(v))
-    print("v min: ", np.max(v))
     q = jnp.transpose(q, (0, 1, 3, 2))
     k = jnp.transpose(k, (0, 1, 3, 2))
     v = jnp.transpose(v, (0, 1, 3, 2))
@@ -446,10 +402,8 @@ def __call__(self, x: jax.Array):
 
     # output projection
     x = self.proj(x)
-    #breakpoint()
     # Reshape back
     x = x.reshape(batch_size, time, height, width, channels)
-    #breakpoint()
 
     return x + identity
 
@@ -467,20 +421,11 @@ def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity
     self.resnets = resnets
 
   def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
-    print("WanMidblock...")
     x = self.resnets[0](x, feat_cache, feat_idx)
-    print("WanMidBlock resnets[0], x min: ", np.min(x))
-    print("WanMidBlock resnets[0], x max: ", np.max(x))
     for attn, resnet in zip(self.attentions, self.resnets[1:]):
-      print("WanMidBlock, for loop, attn len: ", len(self.attentions))
-      print("WanMidBlock, for loop, resnets len: ", len(self.resnets))
       if attn is not None:
         x = attn(x)
-        print("WanMidBlock attn[0], x min: ", np.min(x))
-        print("WanMidBlock attn[0], x max: ", np.max(x))
       x = resnet(x, feat_cache, feat_idx)
-      print("WanMidBlock resnets[i], x min: ", np.min(x))
-      print("WanMidBlock resnets[i], x max: ", np.max(x))
 
     return x
 
@@ -888,21 +833,18 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu
         out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
       else:
         out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
-        out = jnp.concatenate([out, out_], axis=1)
-        print("out_.shape: ", out_.shape)
-        print("out_ min: ", np.min(out_))
-        print("out_ max: ", np.max(out_))
-        print("out.shape: ", out.shape)
-        print("out min: ", np.min(out))
-        print("out max: ", np.max(out))
-    for i in range(len(self._feat_map)):
-      if isinstance(self._feat_map[i], jax.Array):
-        print("i: ", i)
-        print("min: ", np.min(self._feat_map[i]))
-        print("max: ", np.max(self._feat_map[i]))
-      else:
-        print(f"feat_map[{i}] : {self._feat_map[i]}")
-    # breakpoint()
+
+        # This is to bypass an issue where frame[1] should be frame[2] and vice versa.
+        # Ideally this swap shouldn't be needed; however, I can't find where the frames go out of sync.
+        # Most likely it is due to an incorrect reshape in the decoder.
+        fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
+        if len(fm1.shape) == 4:
+          fm1 = jnp.expand_dims(fm1, axis=0)
+          fm2 = jnp.expand_dims(fm2, axis=0)
+          fm3 = jnp.expand_dims(fm3, axis=0)
+          fm4 = jnp.expand_dims(fm4, axis=0)
+
+        out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
     out = jnp.clip(out, min=-1.0, max=1.0)
     self.clear_cache()
     if not return_dict:

From 4245b2419e98441c06f9de6abd4794f56cccfb57 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Thu, 8 May 2025 16:01:46 +0000
Subject: [PATCH 20/25] cleanup unused code.
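
The kwargs removed here (flash_min_seq_length, flash_block_sizes, mesh, dtype,
weights_dtype, precision, attention) were accepted by the VAE building blocks
but appear never to be read, so they are dropped. A minimal sketch of the
surviving constructor surface, using WanResample as an example (dim=96 is
illustrative, not taken from a real config):

  from flax import nnx

  rngs = nnx.Rngs(0)
  # After the cleanup, WanResample takes only its shape arguments plus rngs.
  resample = WanResample(dim=96, mode="upsample2d", rngs=rngs)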
--- .../models/wan/autoencoder_kl_wan.py | 25 +------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index f9603e51e..e7bbc1e23 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -19,11 +19,10 @@ import jax import jax.numpy as jnp from flax import nnx -from ...configuration_utils import ConfigMixin, flax_register_to_config +from ...configuration_utils import ConfigMixin from ..modeling_flax_utils import FlaxModelMixin from ... import common_types from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) -import numpy as np BlockSizes = common_types.BlockSizes CACHE_T = 2 @@ -60,13 +59,6 @@ def __init__( stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, use_bias: bool = True, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", ): self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") self.stride = _canonicalize_tuple(stride, 3, "stride") @@ -191,13 +183,6 @@ def __init__( rngs: nnx.Rngs, kernel_size: Union[int, Tuple[int, int, int]], stride: Union[int, Tuple[int, int, int]] = 1, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", ): self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs) @@ -212,13 +197,6 @@ def __init__( dim: int, mode: str, rngs: nnx.Rngs, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", ): self.dim = dim self.mode = mode @@ -548,7 +526,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): feat_idx[0] += 1 else: x = self.conv_in(x) - # (1, 1, 480, 720, 96) for layer in self.down_blocks: if feat_cache is not None: x = layer(x, feat_cache, feat_idx) From c0ba5c1bc44448d005d33b8fdbd1750aa5b00958 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 8 May 2025 21:08:47 +0000 Subject: [PATCH 21/25] linting --- end_to_end/tpu/eval_assert.py | 3 +- src/maxdiffusion/generate_wan.py | 226 -------------- src/maxdiffusion/models/flux/util.py | 7 +- .../models/modeling_flax_pytorch_utils.py | 5 +- .../models/wan/autoencoder_kl_wan.py | 95 +++--- .../wan/transformers/transformer_flux_wan.py | 287 ------------------ .../transformers/transformer_flux_wan_nnx.py | 69 ----- src/maxdiffusion/models/wan/wan_utils.py | 16 +- src/maxdiffusion/tests/wan_vae_test.py | 82 +++-- src/maxdiffusion/utils/export_utils.py | 144 ++++----- src/maxdiffusion/utils/import_utils.py | 23 +- src/maxdiffusion/utils/loading_utils copy.py | 162 ---------- src/maxdiffusion/utils/loading_utils.py | 153 +++++----- 13 files changed, 283 insertions(+), 989 deletions(-) delete mode 100644 src/maxdiffusion/generate_wan.py delete mode 100644 src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py delete mode 100644 
src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py delete mode 100644 src/maxdiffusion/utils/loading_utils copy.py diff --git a/end_to_end/tpu/eval_assert.py b/end_to_end/tpu/eval_assert.py index 07b5585a6..20fd0d8ae 100644 --- a/end_to_end/tpu/eval_assert.py +++ b/end_to_end/tpu/eval_assert.py @@ -22,7 +22,6 @@ """ - # pylint: skip-file """Reads and asserts over target values""" from absl import app @@ -47,7 +46,7 @@ def test_final_loss(metrics_file, target_loss, num_samples_str="10"): target_loss = float(target_loss) num_samples = int(num_samples_str) with open(metrics_file, "r", encoding="utf8") as _: - last_n_data = get_last_n_data(metrics_file, "learning/loss",num_samples) + last_n_data = get_last_n_data(metrics_file, "learning/loss", num_samples) avg_last_n_data = sum(last_n_data) / len(last_n_data) print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}") print(f"Target loss is {target_loss}") diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py deleted file mode 100644 index ce33c6806..000000000 --- a/src/maxdiffusion/generate_wan.py +++ /dev/null @@ -1,226 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -import html -from typing import Callable, List, Union, Sequence, Optional -import time -import torch -import ftfy -import regex as re -import jax -from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P -from flax import nnx -from absl import app -from transformers import AutoTokenizer, UMT5EncoderModel -from maxdiffusion import pyconfig, max_logging -from maxdiffusion.models.wan.autoencoder_kl_wan import AutoencoderKLWan -from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel -from maxdiffusion.pipelines.wan.pipeline_wan import WanPipeline - -from maxdiffusion.max_utils import ( - device_put_replicated, - get_memory_allocations, - create_device_mesh, - get_flash_block_sizes, - get_precision, - setup_initial_state, -) - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -def prompt_clean(text): - text = whitespace_clean(basic_clean(text)) - return text - - -def _get_t5_prompt_embeds( - tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, - prompt: Union[str, List[str]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(u) for u in prompt] - batch_size = len(prompt) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask - seq_lens = mask.gt(0).sum(dim=1).long() - 
- prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - return prompt_embeds - - -def encode_prompt( - tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - do_classifier_free_guidance: bool = True, - num_videos_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - Whether to use classifier free guidance or not. - num_videos_per_prompt (`int`, *optional*, defaults to 1): - Number of videos that should be generated per prompt. torch device to place the resulting embeddings on - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - device: (`torch.device`, *optional*): - torch device - dtype: (`torch.dtype`, *optional*): - torch dtype - """ - - prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - prompt_embeds = _get_t5_prompt_embeds( - tokenizer=tokenizer, - text_encoder=text_encoder, - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - negative_prompt_embeds = _get_t5_prompt_embeds( - tokenizer=tokenizer, - text_encoder=text_encoder, - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - - return prompt_embeds, negative_prompt_embeds - - -def run(config): - max_logging.log("Wan 2.1 inference script") - - rng = jax.random.key(config.seed) - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - - global_batch_size = config.per_device_batch_size * jax.local_device_count() - - tokenizer = AutoTokenizer.from_pretrained( - config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype - ) - text_encoder = UMT5EncoderModel.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="text_encoder", - ) - s0 = time.perf_counter() - prompt_embeds, negative_prompt_embeds = encode_prompt( - tokenizer=tokenizer, text_encoder=text_encoder, prompt=config.prompt, negative_prompt=config.negative_prompt - ) - max_logging.log(f"text encoding time: {(time.perf_counter() - s0)}") - - # pipeline, params = WanPipeline.from_pretrained( - # config.pretrained_model_name_or_path, - # #vae=None, - # #transformer=None - # ) - # breakpoint() - - pipeline, params = WanPipeline.from_pretrained(config.pretrained_model_name_or_path, vae=None, transformer=None) - - # wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) - - -def main(argv: Sequence[str]) -> None: - pyconfig.initialize(argv) - run(pyconfig.config) - - -if __name__ == "__main__": - app.run(main) diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index 26371e4db..504b71e9f 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -11,11 +11,7 @@ from jax import numpy as jnp from safetensors import safe_open -from ..modeling_flax_pytorch_utils import ( - rename_key, - rename_key_and_reshape_tensor, - torch2jax -) +from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax) from maxdiffusion import max_logging @@ -36,6 +32,7 @@ class FluxParams: rngs: Array param_dtype: DTypeLike + @dataclass class ModelSpec: params: FluxParams diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 74c7fce50..d6a448f98 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -29,6 +29,7 @@ logger = logging.get_logger(__name__) + def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict): """ expected_pytree: dict - a pytree that comes from initializing the model. 
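
For orientation on the hunks around torch2jax below: NumPy has no bfloat16
dtype, so bfloat16 torch tensors are upcast before crossing into NumPy and the
dtype is restored on the JAX side. A minimal restatement of that conversion
pattern (a sketch only, not a new helper introduced by this series):

  import torch
  import jax.numpy as jnp

  def torch2jax_sketch(t: torch.Tensor):
    # NumPy cannot represent bfloat16, so detour through float32.
    is_bf16 = t.dtype == torch.bfloat16
    if is_bf16:
      t = t.to(torch.float32)
    return jnp.array(t.detach().cpu().numpy(), dtype=jnp.bfloat16 if is_bf16 else None)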
@@ -54,6 +55,7 @@ def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict): else: max_logging.log(f"key: {key} not found...") + def torch2jax(torch_tensor: torch.Tensor) -> Array: is_bfloat16 = torch_tensor.dtype == torch.bfloat16 if is_bfloat16: @@ -67,6 +69,7 @@ def torch2jax(torch_tensor: torch.Tensor) -> Array: jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None) return jax_array + def rename_key(key): regex = r"\w+[.]\d+" pats = re.findall(regex, key) @@ -132,7 +135,7 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: pt_tensor = pt_tensor.transpose(2, 3, 1, 0) return renamed_pt_tuple_key, pt_tensor - + # 3d conv layer renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 5: diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index e7bbc1e23..8325c3707 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -23,6 +23,7 @@ from ..modeling_flax_utils import FlaxModelMixin from ... import common_types from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) + BlockSizes = common_types.BlockSizes CACHE_T = 2 @@ -367,10 +368,10 @@ def __call__(self, x: jax.Array): x = self.norm(x) qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3) - #qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3) qkv = jnp.transpose(qkv, (0, 1, 3, 2)) - #q, k, v = jnp.split(qkv, 3, axis=-1) + # q, k, v = jnp.split(qkv, 3, axis=-1) q, k, v = jnp.split(qkv, 3, axis=-2) q = jnp.transpose(q, (0, 1, 3, 2)) k = jnp.transpose(k, (0, 1, 3, 2)) @@ -662,6 +663,32 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): return x +class AutoencoderKLWanCache: + + def __init__(self, module): + self.module = module + self.clear_cache() + + def clear_cache(self): + """Resets cache dictionaries and indices""" + + def _count_conv3d(module): + count = 0 + node_types = nnx.graph.iter_graph([module]) + for _, value in node_types: + if isinstance(value, WanCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.module.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.module.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): def __init__( @@ -745,29 +772,9 @@ def __init__( temperal_upsample=self.temporal_upsample, dropout=dropout, ) - self.clear_cache() - def clear_cache(self): - """Resets cache dictionaries and indices""" - - def _count_conv3d(module): - count = 0 - node_types = nnx.graph.iter_graph([module]) - for path, value in node_types: - if isinstance(value, WanCausalConv3d): - count += 1 - return count - - self._conv_num = _count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - # cache encode - self._enc_conv_num = _count_conv3d(self.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num - - def _encode(self, x: jax.Array): - self.clear_cache() + def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): + feat_cache.clear_cache() if x.shape[-1] != 3: # 
reshape channel last for JAX
       x = jnp.transpose(x, (0, 2, 3, 4, 1))
 
@@ -776,42 +783,46 @@ def _encode(self, x: jax.Array):
     t = x.shape[1]
     iter_ = 1 + (t - 1) // 4
     for i in range(iter_):
-      self._enc_conv_idx = [0]
+      feat_cache._enc_conv_idx = [0]
       if i == 0:
-        out = self.encoder(x[:, :1, :, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+        out = self.encoder(x[:, :1, :, :, :], feat_cache=feat_cache._enc_feat_map, feat_idx=feat_cache._enc_conv_idx)
       else:
         out_ = self.encoder(
-            x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx
+            x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :],
+            feat_cache=feat_cache._enc_feat_map,
+            feat_idx=feat_cache._enc_conv_idx,
         )
         out = jnp.concatenate([out, out_], axis=1)
 
     enc = self.quant_conv(out)
     mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
     enc = jnp.concatenate([mu, logvar], axis=-1)
-    self.clear_cache()
+    feat_cache.clear_cache()
     return enc
 
   def encode(
-      self, x: jax.Array, return_dict: bool = True
+      self, x: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
   ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
     """Encode video into latent distribution."""
-    h = self._encode(x)
+    h = self._encode(x, feat_cache)
     posterior = FlaxDiagonalGaussianDistribution(h)
 
     if not return_dict:
       return (posterior,)
 
     return FlaxAutoencoderKLOutput(latent_dist=posterior)
 
-  def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]:
-    self.clear_cache()
+  def _decode(
+      self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
+  ) -> Union[FlaxDecoderOutput, jax.Array]:
+    feat_cache.clear_cache()
     iter_ = z.shape[1]
     x = self.post_quant_conv(z)
     for i in range(iter_):
-      self._conv_idx = [0]
+      feat_cache._conv_idx = [0]
       if i == 0:
-        out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+        out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=feat_cache._feat_map, feat_idx=feat_cache._conv_idx)
       else:
-        out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
-
-        # This is to bypass an issue where frame[1] should be frame[2] and vice versa. 
+        out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=feat_cache._feat_map, feat_idx=feat_cache._conv_idx)
+
+        # This is to bypass an issue where frame[1] should be frame[2] and vice versa.
         # Ideally this swap shouldn't be needed; however, I can't find where the frames go out of sync.
         # Most likely it is due to an incorrect reshape in the decoder.
fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :] @@ -820,21 +831,23 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu fm2 = jnp.expand_dims(fm2, axis=0) fm3 = jnp.expand_dims(fm3, axis=0) fm4 = jnp.expand_dims(fm4, axis=0) - + out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1) out = jnp.clip(out, min=-1.0, max=1.0) - self.clear_cache() + feat_cache.clear_cache() if not return_dict: return (out,) return FlaxDecoderOutput(sample=out) - def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]: + def decode( + self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True + ) -> Union[FlaxDecoderOutput, jax.Array]: if z.shape[-1] != self.z_dim: # reshape channel last for JAX z = jnp.transpose(z, (0, 2, 3, 4, 1)) assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {z.shape}" - decoded = self._decode(z).sample + decoded = self._decode(z, feat_cache).sample if not return_dict: return (decoded,) return FlaxDecoderOutput(sample=decoded) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py deleted file mode 100644 index 5cd83bbbd..000000000 --- a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py +++ /dev/null @@ -1,287 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ - -from typing import Tuple, Dict, Optional, Any, Union -import jax -import math -import jax.numpy as jnp -from chex import Array -import flax.linen as nn - -from ...attention_flax import FlaxFeedForward, Fla -from ...embeddings_flax import (get_1d_rotary_pos_embed, FlaxTimesteps, FlaxTimestepEmbedding, PixArtAlphaTextProjection) - -from ....configuration_utils import ConfigMixin, flax_register_to_config -from ...modeling_flax_utils import FlaxModelMixin - - -class WanRotaryPosEmbed(nn.Module): - attention_head_dim: int - patch_size: Tuple[int, int, int] - theta: float = 10000.0 - max_seq_len: int - - @nn.compact - def __call__(self, hidden_states: Array) -> Array: - batch_size, num_channels, num_frames, height, width = hidden_states.shape - p_t, p_h, p_w = self.patch_size - ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - - h_dim = w_dim = 2 * (self.attention_head_dim // 6) - t_dim = self.attention_head_dim - h_dim - w_dim - - freqs = [] - for dim in [t_dim, h_dim, w_dim]: - freq = get_1d_rotary_pos_embed(dim, self.max_seq_length, self.theta, freqs_dtype=jnp.float64) - freqs.append(freq) - self.freqs = jnp.concatenate(freqs, dim=1) - - sizes = [ - self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), - self.attention_head_dim // 6, - self.attention_head_dim // 6, - ] - cumulative_sizes = jnp.cumsum(jnp.array(sizes)) - split_indices = cumulative_sizes[:-1] - freqs_split = jnp.split(freqs, split_indices, axis=1) - - freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1) - freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1])) - - freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2) - freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1])) - - freqs_w = jnp.expand_dims(jnp.expand_dims(freqs_split[2][:ppw], axis=0), axis=1) - freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1])) - - freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1) - freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1)) - - return freqs_final - - -class WanImageEmbeddings(nn.Module): - out_features: int - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - - @nn.compact - def __call__(self, encoder_hidden_states_image: Array) -> Array: - hidden_states = nn.LayerNorm( - dtype=jnp.float32, - param_dtype=jnp.float32, - )(encoder_hidden_states_image) - hidden_states = FlaxFeedForward( - self.out_features, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision - )(hidden_states) - hidden_states = nn.LayerNorm( - dtype=jnp.float32, - param_dtype=jnp.float32, - )(hidden_states) - return hidden_states - - -class WanTimeTextImageEmbeddings(nn.Module): - dim: int - time_freq_dim: int - time_proj_dim: int - text_embed_dim: int - image_embed_dim: Optional[int] = None - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - - @nn.compact - def __call__(self, timestep: Array, encoder_hidden_states: Array, encoder_hidden_states_image: Array) -> Array: - - timestep = FlaxTimesteps( - dim=self.time_freq_dim, - flip_sin_to_cos=True, - freq_shift=0, - )(timestep) - temb = FlaxTimestepEmbedding(time_embed_dim=self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype)(timestep) - timestep_proj = nn.Dense( - self.time_proj_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), 
- bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - )(nn.silu(temb)) - encoder_hidden_states = PixArtAlphaTextProjection( - hidden_size=self.dim, - act_fn="gelu_tanh", - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - )(encoder_hidden_states) - - if encoder_hidden_states_image is not None: - encoder_hidden_states_image = WanImageEmbeddings( - out_features=self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision - )(encoder_hidden_states_image) - - return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image - - -class WanTransformerBlock(nn.Module): - dim: int - ffn_dim: int - num_heads: int - qk_norm: str = "rms_norm_across_heads" - cross_attn_norm: bool = False - eps: float = 1e-6 - added_kv_proj_dim: Optional[int] = None - - @nn.compact - def __call__(self, hidden_states: Array, encoder_hidden_states: Array, temb: Array, rotary_emb: Array): - - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( - (scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 - ) - - # 1. Self-attention - norm_hidden_states = ( - nn.LayerNorm( - epsilon=self.eps, - use_bias=False, - use_scale=False, - dtype=jnp.float32, - param_dtype=jnp.float32, - )(hidden_states.astype(jnp.float32)) - * (1 + scale_msa) - + shift_msa - ).astype(hidden_states.dtype) - attn_output = FlaxWanAttention( - query_dim=self.dim, - heads=self.num_heads, - dim_head=self.dim // self.num_heads, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - attention_kernel=self.attention_kernel, - mesh=self.mesh, - flash_block_sizes=self.flash_block_sizes, - ) - - -class WanTransformer3dModel(nn.Module, FlaxModelMixin, ConfigMixin): - r""" - A Transformer model for video-like data used in the Wan model. - - Args: - patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): - 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). - num_attention_heads (`int`, defaults to `40`): - Fixed length for text embeddings. - attention_head_dim (`int`, defaults to `128`): - The number of channels in each head. - in_channels (`int`, defaults to `16`): - The number of channels in the input. - out_channels (`int`, defaults to `16`): - The number of channels in the output. - text_dim (`int`, defaults to `512`): - Input dimension for text embeddings. - freq_dim (`int`, defaults to `256`): - Dimension for sinusoidal time embeddings. - ffn_dim (`int`, defaults to `13824`): - Intermediate dimension in feed-forward network. - num_layers (`int`, defaults to `40`): - The number of layers of transformer blocks to use. - window_size (`Tuple[int]`, defaults to `(-1, -1)`): - Window size for local attention (-1 indicates global attention). - cross_attn_norm (`bool`, defaults to `True`): - Enable cross-attention normalization. - qk_norm (`bool`, defaults to `True`): - Enable query/key normalization. - eps (`float`, defaults to `1e-6`): - Epsilon value for normalization layers. - add_img_emb (`bool`, defaults to `False`): - Whether to use img_emb. - added_kv_proj_dim (`int`, *optional*, defaults to `None`): - The number of channels to use for the added key and value projections. If `None`, no projection is used. 
- """ - - patch_size: Tuple[int] = (1, 2, 2) - num_attention_heads: int = 40 - attention_head_dim: int = 128 - in_channels: int = 16 - out_channels: int = 16 - text_dim: int = 4096 - freq_dim: int = 256 - ffn_dim: int = 13824 - num_layers: int = 40 - cross_attn_norm: bool = True - qk_norm: Optional[str] = "rms_norm_across_heads" - eps: float = 1e-6 - image_dim: Optional[int] = None - added_kv_proj_dim: Optional[int] = None - rope_max_seq_len: int = 1024 - flash_min_seq_length: int = 4096 - flash_block_sizes: BlockSizes = None - mesh: jax.sharding.Mesh = None - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - attention: str = "dot_product" - - @nn.compact - def __call__( - self, - hidden_states: Array, - timestep: Array, - encoder_hidden_states: Array, - encoder_hidden_states_image: Optional[Array] = None, - return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, - ) -> Union[Array, Dict[str, Array]]: - - inner_dim = self.num_attention_heads * self.attention_head_dim - batch_size, num_channels, num_frames, height, width = hidden_states.shape - - p_t, p_h, p_w = self.config.patch_size - post_patch_num_frames = num_frames // p_t - post_patch_height = height // p_h - post_patch_width = width // p_w - - # 1. Patch & position embedding - rotary_emb = WanRotaryPosEmbed( - attention_head_dim=self.attention_head_dim, patch_size=self.patch_size, max_seq_len=self.rope_max_seq_len - )(hidden_states) - hidden_states = nn.Conv( - features=inner_dim, - kernel_size=self.patch_size, - stride=self.patch_size, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - )(hidden_states) - flattened_shape = (batch_size, num_channels, -1) # TODO is his num_channels or frames? - flattened = hidden_states.reshape(flattened_shape) - transposed = jnp.transpose(flattened, (0, 2, 1)) - - # 2. Condition embeddings - # image_embedding_dim=1280 for I2V model - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = WanTimeTextImageEmbeddings( - dim=inner_dim, - time_freq_dim=self.freq_dim, - time_proj_dim=inner_dim * 6, - text_embed_dim=self.text_dim, - image_embed_dim=self.image_dim, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - )(timestep, encoder_hidden_states, encoder_hidden_states_image) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py deleted file mode 100644 index eedba51cc..000000000 --- a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py +++ /dev/null @@ -1,69 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -import jax -import jax.numpy as jnp -from flax import nnx -from .... 
import common_types, max_logging -from ...modeling_flax_utils import FlaxModelMixin -from ....configuration_utils import ConfigMixin - -BlockSizes = common_types.BlockSizes - - -class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): - - def __init__( - self, - rngs: nnx.Rngs, - model_type="t2v", - patch_size=(1, 2, 2), - text_len=512, - in_dim=16, - dim=2038, - ffn_dim=8192, - freq_dim=256, - text_dim=4096, - out_dim=16, - num_heads=16, - num_layers=32, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=True, - eps=1e-6, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", - ): - self.path_embedding = nnx.Conv( - in_dim, - dim, - kernel_size=patch_size, - strides=patch_size, - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("batch",)), - rngs=rngs, - ) - - def __call__(self, x): - x = self.path_embedding(x) - return x diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 2a8517954..b39388089 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -4,12 +4,8 @@ from huggingface_hub import hf_hub_download from safetensors import safe_open from flax.traverse_util import flatten_dict, unflatten_dict -from ..modeling_flax_pytorch_utils import ( - rename_key, - rename_key_and_reshape_tensor, - torch2jax, - validate_flax_state_dict -) +from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict) + def _tuple_str_to_int(in_tuple): out_list = [] @@ -25,7 +21,9 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: device = jax.devices(device)[0] with jax.default_device(device): if hf_download: - ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors") + ckpt_path = hf_hub_download( + pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors" + ) max_logging.log(f"Load and port Wan 2.1 VAE on {device}") if ckpt_path is not None: @@ -73,5 +71,5 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: jax.clear_caches() else: raise FileNotFoundError(f"Path {ckpt_path} was not found") - - return flax_state_dict \ No newline at end of file + + return flax_state_dict diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index e71ddbde2..5e37506d0 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -14,7 +14,7 @@ limitations under the License. """ -import os +import functools import torch import torch.nn as nn import torch.nn.functional as F @@ -25,6 +25,7 @@ import unittest import pytest from absl.testing import absltest +from skimage.metrics import structural_similarity as ssim from ..models.wan.autoencoder_kl_wan import ( WanCausalConv3d, WanUpsample, @@ -35,11 +36,15 @@ WanResample, ZeroPaddedConv2D, WanAttentionBlock, + AutoencoderKLWanCache, ) from ..models.wan.wan_utils import load_wan_vae +from ..utils import load_video +from ..video_processor import VideoProcessor CACHE_T = 2 + class TorchWanRMS_norm(nn.Module): r""" A custom RMS normalization layer. 
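
The test changes below pick up the cache refactor from autoencoder_kl_wan.py:
the causal-conv frame cache now lives in an external AutoencoderKLWanCache
object rather than on the module itself, so encode/decode can be jitted with
the cache passed in, as the updated tests do. A minimal usage sketch (model
path and shapes as in the tests; the all-ones video is a placeholder input):

  import jax
  import jax.numpy as jnp
  from flax import nnx
  from maxdiffusion.models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache

  key = jax.random.key(0)
  vae = AutoencoderKLWan.from_config("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", rngs=nnx.Rngs(key))
  vae_cache = AutoencoderKLWanCache(vae)
  video = jnp.ones((1, 3, 49, 480, 720))  # (batch, channels, frames, height, width), as in the tests
  latents = vae.encode(video, feat_cache=vae_cache).latent_dist.sample(key)  # -> (1, 13, 60, 90, 16)
  frames = vae.decode(latents, feat_cache=vae_cache).sample  # -> (1, 49, 480, 720, 3)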
@@ -91,19 +96,12 @@ def __init__(self, dim: int, mode: str) -> None: WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) ) elif mode == "upsample3d": - # self.resample = nn.Sequential( - # WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) - # ) - # self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) raise Exception("downsample3d not supported") elif mode == "downsample2d": self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) elif mode == "downsample3d": raise Exception("downsample3d not supported") - #self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - #self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) - else: self.resample = nn.Identity() @@ -156,12 +154,6 @@ class WanVaeTest(unittest.TestCase): def setUp(self): WanVaeTest.dummy_data = {} - def test_clear_cache(self): - key = jax.random.key(0) - rngs = nnx.Rngs(key) - wan_vae = AutoencoderKLWan(rngs=rngs) - wan_vae.clear_cache() - def test_wanrms_norm(self): """Test against the Pytorch implementation""" @@ -379,7 +371,7 @@ def test_wan_decode(self): attn_scales=attn_scales, temperal_downsample=temperal_downsample, ) - + vae_cache = AutoencoderKLWanCache(wan_vae) batch = 1 t = 13 channels = 16 @@ -391,7 +383,7 @@ def test_wan_decode(self): latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim) latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim) input = input / latents_std + latents_mean - dummy_output = wan_vae.decode(input) + dummy_output = wan_vae.decode(input, feat_cache=vae_cache) assert dummy_output.sample.shape == (batch, 49, 480, 720, 3) def test_wan_encode(self): @@ -412,6 +404,7 @@ def test_wan_encode(self): attn_scales=attn_scales, temperal_downsample=temperal_downsample, ) + vae_cache = AutoencoderKLWanCache(wan_vae) batch = 1 channels = 3 t = 49 @@ -419,22 +412,51 @@ def test_wan_encode(self): width = 720 input_shape = (batch, channels, t, height, width) input = jnp.ones(input_shape) - output = wan_vae.encode(input) + output = wan_vae.encode(input, feat_cache=vae_cache) assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16) - # def test_load_checkpoint(self): - # pretrained_model_name_or_path = "Wan-AI/Wan2.1-T2V-14B-Diffusers" - # key = jax.random.key(0) - # rngs = nnx.Rngs(key) - # wan_vae = AutoencoderKLWan.from_config( - # pretrained_model_name_or_path, - # subfolder="vae", - # rngs=rngs - # ) - # graphdef, state = nnx.split(wan_vae) - # params = state.to_pure_dict() - # # This replaces random params with the model. 
- # params = load_wan_vae(pretrained_model_name_or_path, params, "cpu") + def test_load_checkpoint(self): + def vae_encode(video, wan_vae, vae_cache, key): + latent = wan_vae.encode(video, feat_cache=vae_cache) + latent = latent.latent_dist.sample(key) + return latent + + pretrained_model_name_or_path = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + key = jax.random.key(0) + rngs = nnx.Rngs(key) + wan_vae = AutoencoderKLWan.from_config(pretrained_model_name_or_path, subfolder="vae", rngs=rngs) + vae_cache = AutoencoderKLWanCache(wan_vae) + video_path, fps = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4", 8 + video = load_video(video_path) + + vae_scale_factor_spatial = 2 ** len(wan_vae.temperal_downsample) + video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial) + width, height = video[0].size + video = video_processor.preprocess_video(video, height=height, width=width) # .to(dtype=jnp.float32) + original_video = jnp.array(np.array(video), dtype=jnp.bfloat16) + + graphdef, state = nnx.split(wan_vae) + params = state.to_pure_dict() + # This replaces random params with the model. + params = load_wan_vae(pretrained_model_name_or_path, params, "cpu") + params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) + wan_vae = nnx.merge(graphdef, params) + + p_vae_encode = jax.jit(functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key)) + original_video_shape = original_video.shape + latent = p_vae_encode(original_video) + + jitted_decode = jax.jit(functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False)) + video = jitted_decode(latent)[0] + video = jnp.transpose(video, (0, 4, 1, 2, 3)) + assert video.shape == original_video_shape + + original_video = torch.from_numpy(np.array(original_video.astype(jnp.float32))).to(dtype=torch.bfloat16) + video = torch.from_numpy(np.array(video)).to(dtype=torch.bfloat16) + video = video_processor.postprocess_video(video, output_type="np") + original_video = video_processor.postprocess_video(original_video, output_type="np") + ssim_compare = ssim(video[0], original_video[0], multichannel=True, channel_axis=-1, data_range=255) + assert ssim_compare >= 0.9999 if __name__ == "__main__": diff --git a/src/maxdiffusion/utils/export_utils.py b/src/maxdiffusion/utils/export_utils.py index 461922e94..5dfa3562f 100644 --- a/src/maxdiffusion/utils/export_utils.py +++ b/src/maxdiffusion/utils/export_utils.py @@ -110,30 +110,32 @@ def export_to_obj(mesh, output_obj_path: str = None): with open(output_obj_path, "w") as f: f.writelines("\n".join(combined_data)) + def _legacy_export_to_video( video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10 ): - if is_opencv_available(): - import cv2 - else: - raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) - if output_video_path is None: - output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + if is_opencv_available(): + import cv2 + else: + raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name - if isinstance(video_frames[0], np.ndarray): - video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] + if isinstance(video_frames[0], np.ndarray): + video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] - elif isinstance(video_frames[0], PIL.Image.Image): - video_frames = 
[np.array(frame) for frame in video_frames] + elif isinstance(video_frames[0], PIL.Image.Image): + video_frames = [np.array(frame) for frame in video_frames] - fourcc = cv2.VideoWriter_fourcc(*"mp4v") - h, w, c = video_frames[0].shape - video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h)) - for i in range(len(video_frames)): - img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) - video_writer.write(img) + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + h, w, c = video_frames[0].shape + video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h)) + for i in range(len(video_frames)): + img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + + return output_video_path - return output_video_path def export_to_video( video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], @@ -143,64 +145,64 @@ def export_to_video( bitrate: Optional[int] = None, macro_block_size: Optional[int] = 16, ) -> str: - """ - quality: - Video output quality. Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to - prevent variable bitrate flags to FFMPEG so you can manually specify them using output_params instead. - Specifying a fixed bitrate using `bitrate` disables this parameter. - - bitrate: - Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead. - Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter - rather than specifiying a fixed bitrate with this parameter. - - macro_block_size: - Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number - imageio will tell ffmpeg to scale the image up to the next closest size divisible by this number. Most codecs - are compatible with a macroblock size of 16 (default), some can go smaller (4, 8). To disable this automatic - feature set it to None or 1, however be warned many players can't decode videos that are odd in size and some - codecs will produce poor results or fail. See https://en.wikipedia.org/wiki/Macroblock. - """ - # TODO: Dhruv. Remove by Diffusers release 0.33.0 - # Added to prevent breaking existing code - if not is_imageio_available(): - logger.warning( - ( - "It is recommended to use `export_to_video` with `imageio` and `imageio-ffmpeg` as a backend. \n" - "These libraries are not present in your environment. Attempting to use legacy OpenCV backend to export video. \n" - "Support for the OpenCV backend will be deprecated in a future Diffusers version" - ) + """ + quality: + Video output quality. Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to + prevent variable bitrate flags to FFMPEG so you can manually specify them using output_params instead. + Specifying a fixed bitrate using `bitrate` disables this parameter. + + bitrate: + Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead. + Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter + rather than specifiying a fixed bitrate with this parameter. + + macro_block_size: + Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number + imageio will tell ffmpeg to scale the image up to the next closest size divisible by this number. Most codecs + are compatible with a macroblock size of 16 (default), some can go smaller (4, 8). 
To disable this automatic + feature set it to None or 1, however be warned many players can't decode videos that are odd in size and some + codecs will produce poor results or fail. See https://en.wikipedia.org/wiki/Macroblock. + """ + # TODO: Dhruv. Remove by Diffusers release 0.33.0 + # Added to prevent breaking existing code + if not is_imageio_available(): + logger.warning( + ( + "It is recommended to use `export_to_video` with `imageio` and `imageio-ffmpeg` as a backend. \n" + "These libraries are not present in your environment. Attempting to use legacy OpenCV backend to export video. \n" + "Support for the OpenCV backend will be deprecated in a future Diffusers version" ) - return _legacy_export_to_video(video_frames, output_video_path, fps) - - if is_imageio_available(): - import imageio - else: - raise ImportError(BACKENDS_MAPPING["imageio"][1].format("export_to_video")) - - try: - imageio.plugins.ffmpeg.get_exe() - except AttributeError: - raise AttributeError( - ( - "Found an existing imageio backend in your environment. Attempting to export video with imageio. \n" - "Unable to find a compatible ffmpeg installation in your environment to use with imageio. Please install via `pip install imageio-ffmpeg" - ) + ) + return _legacy_export_to_video(video_frames, output_video_path, fps) + + if is_imageio_available(): + import imageio + else: + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("export_to_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + ( + "Found an existing imageio backend in your environment. Attempting to export video with imageio. \n" + "Unable to find a compatible ffmpeg installation in your environment to use with imageio. Please install via `pip install imageio-ffmpeg" ) + ) - if output_video_path is None: - output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name - if isinstance(video_frames[0], np.ndarray): - video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] + if isinstance(video_frames[0], np.ndarray): + video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] - elif isinstance(video_frames[0], PIL.Image.Image): - video_frames = [np.array(frame) for frame in video_frames] + elif isinstance(video_frames[0], PIL.Image.Image): + video_frames = [np.array(frame) for frame in video_frames] - with imageio.get_writer( - output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size - ) as writer: - for frame in video_frames: - writer.append_data(frame) + with imageio.get_writer( + output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size + ) as writer: + for frame in video_frames: + writer.append_data(frame) - return output_video_path + return output_video_path diff --git a/src/maxdiffusion/utils/import_utils.py b/src/maxdiffusion/utils/import_utils.py index 224bad16c..d83596e8d 100644 --- a/src/maxdiffusion/utils/import_utils.py +++ b/src/maxdiffusion/utils/import_utils.py @@ -51,18 +51,20 @@ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + def _is_package_available(pkg_name: str): - pkg_exists = importlib.util.find_spec(pkg_name) is not None - pkg_version = "N/A" + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" + + if pkg_exists: + try: + pkg_version = importlib_metadata.version(pkg_name) + 
logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False - if pkg_exists: - try: - pkg_version = importlib_metadata.version(pkg_name) - logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") - except (ImportError, importlib_metadata.PackageNotFoundError): - pkg_exists = False + return pkg_exists, pkg_version - return pkg_exists, pkg_version _torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: @@ -298,9 +300,11 @@ def _is_package_available(pkg_name: str): except importlib_metadata.PackageNotFoundError: _peft_available = False + def is_imageio_available(): return _imageio_available + def is_torch_available(): return _torch_available @@ -727,4 +731,3 @@ def _get_module(self, module_name: str): def __reduce__(self): return (self.__class__, (self._name, self.__file__, self._import_structure)) - diff --git a/src/maxdiffusion/utils/loading_utils copy.py b/src/maxdiffusion/utils/loading_utils copy.py deleted file mode 100644 index fd66aaa4d..000000000 --- a/src/maxdiffusion/utils/loading_utils copy.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -import tempfile -from typing import Any, Callable, List, Optional, Tuple, Union -from urllib.parse import unquote, urlparse - -import PIL.Image -import PIL.ImageOps -import requests - -from .import_utils import BACKENDS_MAPPING, is_imageio_available - - -def load_image( - image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None -) -> PIL.Image.Image: - """ - Loads `image` to a PIL Image. - - Args: - image (`str` or `PIL.Image.Image`): - The image to convert to the PIL Image format. - convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*): - A conversion method to apply to the image after loading it. When set to `None` the image will be converted - "RGB". - - Returns: - `PIL.Image.Image`: - A PIL Image. - """ - if isinstance(image, str): - if image.startswith("http://") or image.startswith("https://"): - image = PIL.Image.open(requests.get(image, stream=True).raw) - elif os.path.isfile(image): - image = PIL.Image.open(image) - else: - raise ValueError( - f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path." - ) - elif isinstance(image, PIL.Image.Image): - image = image - else: - raise ValueError( - "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image." - ) - - image = PIL.ImageOps.exif_transpose(image) - - if convert_method is not None: - image = convert_method(image) - else: - image = image.convert("RGB") - - return image - - -def load_video( - video: str, - convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, -) -> List[PIL.Image.Image]: - """ - Loads `video` to a list of PIL Image. - - Args: - video (`str`): - A URL or Path to a video to convert to a list of PIL Image format. - convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): - A conversion method to apply to the video after loading it. When set to `None` the images will be converted - to "RGB". - - Returns: - `List[PIL.Image.Image]`: - The video as a list of PIL images. 
- """ - is_url = video.startswith("http://") or video.startswith("https://") - is_file = os.path.isfile(video) - was_tempfile_created = False - - if not (is_url or is_file): - raise ValueError( - f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." - ) - - if is_url: - response = requests.get(video, stream=True) - if response.status_code != 200: - raise ValueError(f"Failed to download video. Status code: {response.status_code}") - - parsed_url = urlparse(video) - file_name = os.path.basename(unquote(parsed_url.path)) - - suffix = os.path.splitext(file_name)[1] or ".mp4" - video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name - - was_tempfile_created = True - - video_data = response.iter_content(chunk_size=8192) - with open(video_path, "wb") as f: - for chunk in video_data: - f.write(chunk) - - video = video_path - - pil_images = [] - if video.endswith(".gif"): - gif = PIL.Image.open(video) - try: - while True: - pil_images.append(gif.copy()) - gif.seek(gif.tell() + 1) - except EOFError: - pass - - else: - if is_imageio_available(): - import imageio - else: - raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video")) - - try: - imageio.plugins.ffmpeg.get_exe() - except AttributeError: - raise AttributeError( - "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg" - ) - - with imageio.get_reader(video) as reader: - # Read all frames - for frame in reader: - pil_images.append(PIL.Image.fromarray(frame)) - - if was_tempfile_created: - os.remove(video_path) - - if convert_method is not None: - pil_images = convert_method(pil_images) - - return pil_images - - -# Taken from `transformers`. -def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]: - if "." in tensor_name: - splits = tensor_name.split(".") - for split in splits[:-1]: - new_module = getattr(module, split) - if new_module is None: - raise ValueError(f"{module} has no attribute {split}.") - module = new_module - tensor_name = splits[-1] - return module, tensor_name - - -def get_submodule_by_name(root_module, module_path: str): - current = root_module - parts = module_path.split(".") - for part in parts: - if part.isdigit(): - idx = int(part) - current = current[idx] # e.g., for nn.ModuleList or nn.Sequential - else: - current = getattr(current, part) - return current diff --git a/src/maxdiffusion/utils/loading_utils.py b/src/maxdiffusion/utils/loading_utils.py index 77caa39e0..4a480b3c7 100644 --- a/src/maxdiffusion/utils/loading_utils.py +++ b/src/maxdiffusion/utils/loading_utils.py @@ -37,85 +37,86 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: image = image.convert("RGB") return image + def load_video( video: str, convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, ) -> List[PIL.Image.Image]: - """ - Loads `video` to a list of PIL Image. - - Args: - video (`str`): - A URL or Path to a video to convert to a list of PIL Image format. - convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): - A conversion method to apply to the video after loading it. When set to `None` the images will be converted - to "RGB". - - Returns: - `List[PIL.Image.Image]`: - The video as a list of PIL images. 
- """ - is_url = video.startswith("http://") or video.startswith("https://") - is_file = os.path.isfile(video) - was_tempfile_created = False - - if not (is_url or is_file): - raise ValueError( - f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." - ) - - if is_url: - response = requests.get(video, stream=True) - if response.status_code != 200: - raise ValueError(f"Failed to download video. Status code: {response.status_code}") - - parsed_url = urlparse(video) - file_name = os.path.basename(unquote(parsed_url.path)) - - suffix = os.path.splitext(file_name)[1] or ".mp4" - video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name - - was_tempfile_created = True - - video_data = response.iter_content(chunk_size=8192) - with open(video_path, "wb") as f: - for chunk in video_data: - f.write(chunk) - - video = video_path - - pil_images = [] - if video.endswith(".gif"): - gif = PIL.Image.open(video) - try: - while True: - pil_images.append(gif.copy()) - gif.seek(gif.tell() + 1) - except EOFError: - pass + """ + Loads `video` to a list of PIL Image. + + Args: + video (`str`): + A URL or Path to a video to convert to a list of PIL Image format. + convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): + A conversion method to apply to the video after loading it. When set to `None` the images will be converted + to "RGB". + + Returns: + `List[PIL.Image.Image]`: + The video as a list of PIL images. + """ + is_url = video.startswith("http://") or video.startswith("https://") + is_file = os.path.isfile(video) + was_tempfile_created = False + + if not (is_url or is_file): + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." + ) + + if is_url: + response = requests.get(video, stream=True) + if response.status_code != 200: + raise ValueError(f"Failed to download video. Status code: {response.status_code}") + + parsed_url = urlparse(video) + file_name = os.path.basename(unquote(parsed_url.path)) + + suffix = os.path.splitext(file_name)[1] or ".mp4" + video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name + was_tempfile_created = True + + video_data = response.iter_content(chunk_size=8192) + with open(video_path, "wb") as f: + for chunk in video_data: + f.write(chunk) + + video = video_path + + pil_images = [] + if video.endswith(".gif"): + gif = PIL.Image.open(video) + try: + while True: + pil_images.append(gif.copy()) + gif.seek(gif.tell() + 1) + except EOFError: + pass + + else: + if is_imageio_available(): + import imageio else: - if is_imageio_available(): - import imageio - else: - raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video")) - - try: - imageio.plugins.ffmpeg.get_exe() - except AttributeError: - raise AttributeError( - "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg" - ) - - with imageio.get_reader(video) as reader: - # Read all frames - for frame in reader: - pil_images.append(PIL.Image.fromarray(frame)) - - if was_tempfile_created: - os.remove(video_path) - - if convert_method is not None: - pil_images = convert_method(pil_images) - - return pil_images \ No newline at end of file + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + "`Unable to find an ffmpeg installation on your machine. 
Please install via `pip install imageio-ffmpeg" + ) + + with imageio.get_reader(video) as reader: + # Read all frames + for frame in reader: + pil_images.append(PIL.Image.fromarray(frame)) + + if was_tempfile_created: + os.remove(video_path) + + if convert_method is not None: + pil_images = convert_method(pil_images) + + return pil_images From bab4d1704944c23e451d8fdc0c78d9e4b5de0de6 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 8 May 2025 21:10:30 +0000 Subject: [PATCH 22/25] remove wan from readme. --- README.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/README.md b/README.md index 2dc6523c4..e14603ac4 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,6 @@ MaxDiffusion supports - [Training](#training) - [Dreambooth](#dreambooth) - [Inference](#inference) - - [Wan 2.1](#wan) - [Flux](#flux) - [Fused Attention for GPU:](#fused-attention-for-gpu) - [Hyper SDXL LoRA](#hyper-sdxl-lora) @@ -172,12 +171,6 @@ To generate images, run the following command: ```bash python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run" ``` - - ## Wan - - ```bash - python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_t2v.yml run_name="wan-test" output_dir="gs://jfacevedo-maxdiffusion" jax_cache_dir="/tmp/" - ``` ## Flux From 7a8daedd71051850a9dd9ec547521c1a44cc95fd Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 8 May 2025 21:13:23 +0000 Subject: [PATCH 23/25] remove unused files. --- src/maxdiffusion/pipelines/wan/__init__.py | 0 .../pipelines/wan/pipeline_wan.py | 84 ------------------- ...heduling_flow_match_euler_discrete_flax.py | 74 ---------------- 3 files changed, 158 deletions(-) delete mode 100644 src/maxdiffusion/pipelines/wan/__init__.py delete mode 100644 src/maxdiffusion/pipelines/wan/pipeline_wan.py delete mode 100644 src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py diff --git a/src/maxdiffusion/pipelines/wan/__init__.py b/src/maxdiffusion/pipelines/wan/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/maxdiffusion/pipelines/wan/pipeline_wan.py b/src/maxdiffusion/pipelines/wan/pipeline_wan.py deleted file mode 100644 index 27b5844a8..000000000 --- a/src/maxdiffusion/pipelines/wan/pipeline_wan.py +++ /dev/null @@ -1,84 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-""" - -from typing import Union, List -from transformers import AutoTokenizer, UMT5EncoderModel -import torch -from ...models.wan.transformers.transformer_flux_wan_nnx import WanModel -from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan -from ..pipeline_flax_utils import FlaxDiffusionPipeline -from ...video_processor import VideoProcessor -# from ...schedulers import FlowMatchEulerDiscreteScheduler - - -class WanPipeline(FlaxDiffusionPipeline): - - def __init__( - self, - tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, - transformer: WanModel, - vae: AutoencoderKLWan, - # scheduler: FlowMatchEulerDiscreteScheduler, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - transformer=transformer, - # scheduler=scheduler - ) - - self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 - self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - - def _get_t5_prompt_embds( - self, - prompt: Union[str, List[str]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - ): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask - seq_lens = mask.gt(0).sum(dim=1).long() - - prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state - # prompt_embeds = prompt_embeds.to(dtype=dtype) - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - return prompt_embeds diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py deleted file mode 100644 index ec29a353e..000000000 --- a/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py +++ /dev/null @@ -1,74 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-""" - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import flax -import jax.numpy as jnp - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import ( - CommonSchedulerState, - # FlaxKarrasDiffusionSchedulers, - FlaxSchedulerMixin, - FlaxSchedulerOutput, - broadcast_to_shape_from_left, -) - - -@flax.struct.dataclass -class FlowMatchEulerDiscreteSchedulerState: - common: CommonSchedulerState - - -@dataclass -class FlowMatchEulerDiscreteSchedulerOutput(FlaxSchedulerOutput): - state: FlowMatchEulerDiscreteSchedulerState - - -class FlowMatchEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): - # _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] - - dtype: jnp.dtype - - @property - def has_state(self): - return True - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - shift: float = 1.0, - use_dynamic_shifting: bool = False, - base_shift: Optional[float] = 0.5, - max_shift: Optional[float] = 1.15, - base_image_seq_len: Optional[int] = 256, - max_image_seq_len: Optional[int] = 4096, - invert_sigmas: bool = False, - shift_terminal: Optional[float] = None, - use_karras_sigmas: Optional[bool] = False, - use_exponential_sigmas: Optional[bool] = False, - use_beta_sigmas: Optional[bool] = False, - time_shift_type: str = "exponential", - dtype: jnp.dtype = jnp.float32, - ): - self.dtype = dtype - - def create_state(self, common: Optional[CommonSchedulerState] = None) -> FlowMatchEulerDiscreteSchedulerState: - if common is None: - common = CommonSchedulerState.create(self) From b31b4ad9b105004a9388c5fcc32ab26bef76b24b Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 8 May 2025 21:36:41 +0000 Subject: [PATCH 24/25] more linter fixes. 
---
 src/maxdiffusion/models/flux/util.py     | 1 -
 src/maxdiffusion/models/wan/wan_utils.py | 4 ++--
 src/maxdiffusion/tests/wan_vae_test.py   | 9 +++------
 src/maxdiffusion/utils/loading_utils.py  | 2 +-
 4 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py
index 504b71e9f..d856fda5d 100644
--- a/src/maxdiffusion/models/flux/util.py
+++ b/src/maxdiffusion/models/flux/util.py
@@ -4,7 +4,6 @@

 import jax
 from jax.typing import DTypeLike
-import torch  # need for torch 2 jax
 from chex import Array
 from flax.traverse_util import flatten_dict, unflatten_dict
 from huggingface_hub import hf_hub_download
diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py
index b39388089..1a9948fdb 100644
--- a/src/maxdiffusion/models/wan/wan_utils.py
+++ b/src/maxdiffusion/models/wan/wan_utils.py
@@ -3,7 +3,7 @@
 from maxdiffusion import max_logging
 from huggingface_hub import hf_hub_download
 from safetensors import safe_open
-from flax.traverse_util import flatten_dict, unflatten_dict
+from flax.traverse_util import unflatten_dict
 from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)


@@ -12,7 +12,7 @@ def _tuple_str_to_int(in_tuple):
   for item in in_tuple:
     try:
       out_list.append(int(item))
-    except:
+    except ValueError:
       out_list.append(item)
   return tuple(out_list)

diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py
index 5e37506d0..7d750c8bb 100644
--- a/src/maxdiffusion/tests/wan_vae_test.py
+++ b/src/maxdiffusion/tests/wan_vae_test.py
@@ -23,7 +23,6 @@
 from flax import nnx
 import numpy as np
 import unittest
-import pytest
 from absl.testing import absltest
 from skimage.metrics import structural_similarity as ssim
 from ..models.wan.autoencoder_kl_wan import (
@@ -172,7 +171,7 @@ def test_wanrms_norm(self):
     dummy_input = jnp.ones(input_shape)
     output = wanrms_norm(dummy_input)
     output_np = np.array(output)
-    assert np.allclose(output_np, torch_output_np) == True
+    assert np.allclose(output_np, torch_output_np)

     # --- Test Case 2: images == False ---
     model = TorchWanRMS_norm(dim, images=False)
@@ -186,7 +185,7 @@ def test_wanrms_norm(self):
     dummy_input = jnp.ones(input_shape)
     output = wanrms_norm(dummy_input)
     output_np = np.array(output)
-    assert np.allclose(output_np, torch_output_np) == True
+    assert np.allclose(output_np, torch_output_np)

   def test_zero_padded_conv(self):

@@ -235,8 +234,6 @@ def test_wan_resample(self):
     w = 720
     mode = "downsample2d"
     input_shape = (batch, dim, t, h, w)
-    expected_output_shape = (1, dim, 1, 240, 360)
-    # output dim should be (1, 96, 1, 480, 720)
     dummy_input = torch.ones(input_shape)
     torch_wan_resample = TorchWanResample(dim=dim, mode=mode)
     torch_output = torch_wan_resample(dummy_input)
@@ -426,7 +423,7 @@ def vae_encode(video, wan_vae, vae_cache, key):
     rngs = nnx.Rngs(key)
     wan_vae = AutoencoderKLWan.from_config(pretrained_model_name_or_path, subfolder="vae", rngs=rngs)
     vae_cache = AutoencoderKLWanCache(wan_vae)
-    video_path, fps = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4", 8
+    video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
     video = load_video(video_path)

     vae_scale_factor_spatial = 2 ** len(wan_vae.temperal_downsample)
diff --git a/src/maxdiffusion/utils/loading_utils.py b/src/maxdiffusion/utils/loading_utils.py
index 4a480b3c7..6107272c7 100644
--- a/src/maxdiffusion/utils/loading_utils.py
+++ b/src/maxdiffusion/utils/loading_utils.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Union

 import PIL.Image
 import PIL.ImageOps

From 8c3af8bfe4a545d12dad369844cbe164c129cdcb Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Thu, 8 May 2025 22:06:33 +0000
Subject: [PATCH 25/25] update requirements

---
 requirements.txt                       | 4 +++-
 requirements_with_jax_stable_stack.txt | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index defbb1512..e26b45b80 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -30,4 +30,6 @@ huggingface_hub==0.30.2
 transformers==4.48.1
 einops==0.8.0
 sentencepiece
-aqtp
\ No newline at end of file
+aqtp
+imageio==2.37.0
+imageio-ffmpeg==0.6.0
\ No newline at end of file
diff --git a/requirements_with_jax_stable_stack.txt b/requirements_with_jax_stable_stack.txt
index 80ad1434e..5a88c800f 100644
--- a/requirements_with_jax_stable_stack.txt
+++ b/requirements_with_jax_stable_stack.txt
@@ -31,4 +31,6 @@ tensorflow-datasets>=4.9.6
 tokenizers==0.21.0
 torch==2.5.1
 torchvision==0.20.1
-transformers==4.48.1
\ No newline at end of file
+transformers==4.48.1
+imageio==2.37.0
+imageio-ffmpeg==0.6.0
\ No newline at end of file
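
With `imageio` and `imageio-ffmpeg` pinned, the imageio export path added earlier
in this series can be exercised end to end. A minimal usage sketch — the import
path `maxdiffusion.utils` and the frame shape/count are assumptions for
illustration, not taken from the patches:

```python
import numpy as np

# export_to_video's location is assumed from this series' utils package.
from maxdiffusion.utils import export_to_video

# Eight dummy RGB frames in [0, 1], as a pipeline would typically return them;
# export_to_video scales ndarray frames by 255 and casts to uint8 internally.
frames = [np.random.rand(128, 128, 3).astype(np.float32) for _ in range(8)]

# With the pins above installed this takes the imageio writer path; without
# them it falls back to the legacy OpenCV backend and logs a warning.
print(export_to_video(frames, output_video_path="sample.mp4", fps=8))
```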