From 15d242eb34e54613da7d5e4ee4dfd04eeb190a3e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Sun, 9 Mar 2025 20:49:12 +0000 Subject: [PATCH 01/54] wip - wan transformer --- src/maxdiffusion/models/attention_flax.py | 144 ++++++++- src/maxdiffusion/models/embeddings_flax.py | 4 +- src/maxdiffusion/models/wan/__init__.py | 15 + .../wan/transformers/transformer_flux_wan.py | 305 ++++++++++++++++++ 4 files changed, 456 insertions(+), 12 deletions(-) create mode 100644 src/maxdiffusion/models/wan/__init__.py create mode 100644 src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index db8626984..8008852ea 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -383,6 +383,139 @@ def chunk_scanner(chunk_idx, _): return jnp.concatenate(res, axis=-3) # fuse the chunked result back +def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) + + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + + return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) + +class FlaxWanAttention(nn.module): + query_dim: int + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + attention_kernel: str = "dot_product" + flash_min_seq_length: int = 4096 + flash_block_sizes: BlockSizes = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + query_axis_names: AxisNames = (BATCH, LENGTH, HEAD) + key_axis_names: AxisNames = (BATCH, LENGTH, HEAD) + value_axis_names: AxisNames = (BATCH, LENGTH, HEAD) + out_axis_names: AxisNames = (BATCH, LENGTH, EMBED) + precision: jax.lax.Precision = None + qkv_bias: bool = False + + def setup(self): + if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: + raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") + inner_dim = self.dim_head * self.heads + scale = self.dim_head**-0.5 + + self.attention_op = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + scale=scale, + heads=self.heads, + dim_head=self.dim_head, + flash_min_seq_length=self.flash_min_seq_length, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + flash_block_sizes=self.flash_block_sizes, + dtype=self.dtype, + float32_qk_product=False, + ) + + kernel_axes = ("embed", "heads") + qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes) + + qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "heads")) + + self.query = nn.Dense( + inner_dim, + kernel_init=qkv_init_kernel, + use_bias=False, + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="to_q", + precision=self.precision, + ) + + self.key = nn.Dense( + inner_dim, + kernel_init=qkv_init_kernel, + use_bias=False, + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="to_k", + precision=self.precision, + ) + + self.value = nn.Dense( + inner_dim, + kernel_init=qkv_init_kernel, + use_bias=False, + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="to_v", + precision=self.precision, + ) + + self.query_norm = nn.RMSNorm( + dtype=self.dtype, + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), + param_dtype=self.weights_dtype, + ) + self.key_norm = nn.RMSNorm( + dtype=self.dtype, + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), + param_dtype=self.weights_dtype, + ) + + self.proj_attn = nn.Dense( + self.query_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("heads", "embed")), + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="to_out_0", + precision=self.precision, + ) + self.dropout_layer = nn.Dropout(rate=self.dropout) + + def call( + self, + hidden_states: Array, + encoder_hidden_states: Optional[Array], + rotary_emb: Optional[Array], + deterministic: bool = True + ): + encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + + query_proj = self.query(hidden_states) + key_proj = self.key(encoder_hidden_states) + value_proj = self.value(encoder_hidden_states) + + query_proj = self.query_norm(query_proj) + key_proj = self.key_norm(key_proj) + + if rotary_emb: + query_proj, key_proj = self.apply_rope(query_proj, key_proj, rotary_emb) + + query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names) + key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) + value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) + + hidden_states = self.attention_op.apply_attention(query_proj, key_proj, value_proj) + + hidden_states = self.proj_attn(hidden_states) + hidden_states = nn.with_logical_constraint(hidden_states, (BATCH, LENGTH, HEAD)) + return self.dropout_layer(hidden_states, deterministic=deterministic) class FlaxFluxAttention(nn.Module): query_dim: int @@ -493,15 +626,6 @@ def setup(self): param_dtype=self.weights_dtype, ) - def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: - xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) - - xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] - xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - - return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) - def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None): qkv_proj = self.qkv(hidden_states) @@ -535,7 +659,7 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=Non value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2) - query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb) + query_proj, key_proj = apply_rope(query_proj, key_proj, image_rotary_emb) query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1) key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1) diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index 42ca4b950..3377ef6cc 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -56,7 +56,6 @@ def get_sinusoidal_embeddings( signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) return signal - class FlaxTimestepEmbedding(nn.Module): r""" Time step Embedding Module. Learns embeddings for input time steps. @@ -91,7 +90,8 @@ class FlaxTimesteps(nn.Module): dim: int = 32 flip_sin_to_cos: bool = False - freq_shift: float = 1 + freq_shift: float = 1.0 + scale: int = 1 @nn.compact def __call__(self, timesteps): diff --git a/src/maxdiffusion/models/wan/__init__.py b/src/maxdiffusion/models/wan/__init__.py new file mode 100644 index 000000000..7e4185f36 --- /dev/null +++ b/src/maxdiffusion/models/wan/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py new file mode 100644 index 000000000..968798c5d --- /dev/null +++ b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py @@ -0,0 +1,305 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +from typing import Tuple, Dict, Optional, Any, Union +import jax +import math +import jax.numpy as jnp +from chex import Array +import flax +import flax.linen as nn + +from ...attention_flax import FlaxFeedForward, Fla +from ...embeddings_flax import ( + get_1d_rotary_pos_embed, + FlaxTimesteps, + FlaxTimestepEmbedding, + PixArtAlphaTextProjection +) + +from ....configuration_utils import ConfigMixin, flax_register_to_config +from ...modeling_flax_utils import FlaxModelMixin + +class WanRotaryPosEmbed(nn.Module): + attention_head_dim: int + patch_size: Tuple[int, int, int] + theta: float = 10000.0 + max_seq_len: int + + @nn.compact + def __call__(self, hidden_states: Array) -> Array: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + h_dim = w_dim = 2 * (self.attention_head_dim // 6) + t_dim = self.attention_head_dim - h_dim - w_dim + + freqs = [] + for dim in [t_dim, h_dim, w_dim]: + freq = get_1d_rotary_pos_embed(dim, self.max_seq_length, self.theta, freqs_dtype=jnp.float64) + freqs.append(freq) + self.freqs = jnp.concatenate(freqs, dim=1) + + sizes = [ + self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), + self.attention_head_dim // 6, + self.attention_head_dim // 6 + ] + cumulative_sizes = jnp.cumsum(jnp.array(sizes)) + split_indices = cumulative_sizes[:-1] + freqs_split = jnp.split(freqs, split_indices, axis=1) + + freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1) + freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1])) + + freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2) + freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1])) + + freqs_w = jnp.expand_dims(jnp.expand_dims(freqs_split[2][:ppw], axis=0), axis=1) + freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1])) + + freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1) + freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1)) + + return freqs_final + +class WanImageEmbeddings(nn.Module): + out_features: int + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, encoder_hidden_states_image: Array) -> Array: + hidden_states = nn.LayerNorm( + dtype=jnp.float32, + param_dtype=jnp.float32, + )(encoder_hidden_states_image) + hidden_states = FlaxFeedForward( + self.out_features, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + )(hidden_states) + hidden_states = nn.LayerNorm( + dtype=jnp.float32, + param_dtype=jnp.float32, + )(hidden_states) + return hidden_states + + +class WanTimeTextImageEmbeddings(nn.Module): + dim: int + time_freq_dim: int + time_proj_dim: int + text_embed_dim: int + image_embed_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, timestep: Array, encoder_hidden_states: Array, encoder_hidden_states_image: Array) -> Array: + + timestep = FlaxTimesteps( + dim=self.time_freq_dim, + flip_sin_to_cos=True, + freq_shift=0, + )(timestep) + temb = FlaxTimestepEmbedding( + time_embed_dim=self.dim, + dtype=self.dtype, + weights_dtype=self.weights_dtype + )(timestep) + timestep_proj = nn.Dense( + self.time_proj_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + )(nn.silu(temb)) + encoder_hidden_states = PixArtAlphaTextProjection( + hidden_size=self.dim, + act_fn="gelu_tanh", + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + )(encoder_hidden_states) + + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = WanImageEmbeddings( + out_features=self.dim, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + )(encoder_hidden_states_image) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + +class WanTransformerBlock(nn.Module): + dim: int + ffn_dim: int + num_heads: int + qk_norm: str = "rms_norm_across_heads" + cross_attn_norm: bool = False + eps: float = 1e-6 + added_kv_proj_dim: Optional[int] = None + + @nn.compact + def __call__( + self, + hidden_states: Array, + encoder_hidden_states: Array, + temb: Array, + rotary_emb: Array + ): + + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( + (scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 + ) + + # 1. Self-attention + norm_hidden_states = (nn.LayerNorm( + epsilon=self.eps, + use_bias=False, + use_scale=False, + dtype=jnp.float32, + param_dtype=jnp.float32, + )(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) + attn_output = FlaxWanAttention( + query_dim=self.dim, + heads=self.num_heads, + dim_head=self.dim // self.num_heads, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + attention_kernel=self.attention_kernel, + mesh=self.mesh, + flash_block_sizes=self.flash_block_sizes, + + ) + +class WanTransformer3dModel(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + A Transformer model for video-like data used in the Wan model. + + Args: + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `40`): + Fixed length for text embeddings. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_dim (`int`, defaults to `512`): + Input dimension for text embeddings. + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + The number of layers of transformer blocks to use. + window_size (`Tuple[int]`, defaults to `(-1, -1)`): + Window size for local attention (-1 indicates global attention). + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + qk_norm (`bool`, defaults to `True`): + Enable query/key normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + add_img_emb (`bool`, defaults to `False`): + Whether to use img_emb. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + """ + + patch_size: Tuple[int] = (1, 2, 2) + num_attention_heads: int = 40 + attention_head_dim: int = 128 + in_channels: int = 16 + out_channels: int = 16 + text_dim: int = 4096 + freq_dim: int = 256 + ffn_dim: int = 13824 + num_layers: int = 40 + cross_attn_norm: bool = True + qk_norm: Optional[str] = "rms_norm_across_heads" + eps: float = 1e-6 + image_dim: Optional[int] = None + added_kv_proj_dim: Optional[int] = None + rope_max_seq_len: int = 1024 + flash_min_seq_length: int = 4096 + flash_block_sizes: BlockSizes = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + attention: str = "dot_product" + + @nn.compact + def __call__( + self, + hidden_states: Array, + timestep: Array, + encoder_hidden_states: Array, + encoder_hidden_states_image: Optional[Array] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None + ) -> Union[Array, Dict[str, Array]]: + + inner_dim = self.num_attention_heads * self.attention_head_dim + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # 1. Patch & position embedding + rotary_emb = WanRotaryPosEmbed( + attention_head_dim=self.attention_head_dim, + patch_size=self.patch_size, + max_seq_len=self.rope_max_seq_len + )(hidden_states) + hidden_states = nn.Conv( + features=inner_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + )(hidden_states) + flattened_shape = (batch_size, num_channels, -1) # TODO is his num_channels or frames? + flattened = hidden_states.reshape(flattened_shape) + transposed = jnp.transpose(flattened, (0, 2, 1)) + + # 2. Condition embeddings + # image_embedding_dim=1280 for I2V model + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = WanTimeTextImageEmbeddings( + dim=inner_dim, + time_freq_dim=self.freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=self.text_dim, + image_embed_dim=self.image_dim, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + )(timestep, encoder_hidden_states, encoder_hidden_states_image) + From 9b63238cffeced18a4b7d1836f4f093bff564d2d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 16 Apr 2025 15:06:42 +0000 Subject: [PATCH 02/54] adding nnx - wip --- src/maxdiffusion/configs/base_wan_t2v.yml | 271 ++++++++++++++++++ src/maxdiffusion/generate_wan.py | 33 +++ src/maxdiffusion/models/attention_flax.py | 4 +- .../wan/transformers/transformer_flux_wan.py | 1 - .../transformers/transformer_flux_wan_nnx.py | 70 +++++ 5 files changed, 376 insertions(+), 3 deletions(-) create mode 100644 src/maxdiffusion/configs/base_wan_t2v.yml create mode 100644 src/maxdiffusion/generate_wan.py create mode 100644 src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py diff --git a/src/maxdiffusion/configs/base_wan_t2v.yml b/src/maxdiffusion/configs/base_wan_t2v.yml new file mode 100644 index 000000000..944153d64 --- /dev/null +++ b/src/maxdiffusion/configs/base_wan_t2v.yml @@ -0,0 +1,271 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' +clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' +t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' + +# Flux params +flux_name: "flux-dev" +max_sequence_length: 512 +time_shift: True +base_shift: 0.5 +max_shift: 1.15 +# offloads t5 encoder after text encoding to save memory. +offload_encoders: True + + +unet_checkpoint: '' +revision: 'refs/pr/95' +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + +# Set true to load weights from pytorch +from_pt: True +split_head_dim: True +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te + +flash_block_sizes: {} +# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. +# flash_block_sizes: { +# "block_q" : 1536, +# "block_kv_compute" : 1536, +# "block_kv" : 1536, +# "block_q_dkv" : 1536, +# "block_kv_dkv" : 1536, +# "block_kv_dkv_compute" : 1536, +# "block_q_dq" : 1536, +# "block_kv_dq" : 1536 +# } +# GroupNorm groups +norm_num_groups: 32 + +# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch +# else they will be loaded from pretrained_model_name_or_path +train_new_unet: False + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +diffusion_scheduler_config: { + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', + rescale_zero_terminal_snr: False, + timestep_spacing: 'trailing' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' + +# Parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], + ['activation_kv', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: -1 +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 4.e-7 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 200 +num_train_epochs: 1 +seed: 0 +output_dir: 'sdxl-model-finetuned' +per_device_batch_size: 1 + +warmup_steps_fraction: 0.0 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 1.e-2 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Generation parameters +prompt: "A magical castle in the middle of a forest, artistic drawing" +prompt_2: "A magical castle in the middle of a forest, artistic drawing" +negative_prompt: "purple, red" +do_classifier_free_guidance: True +guidance_scale: 3.5 +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 50 + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. + diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py new file mode 100644 index 000000000..599fbc189 --- /dev/null +++ b/src/maxdiffusion/generate_wan.py @@ -0,0 +1,33 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Callable, List, Union, Sequence +from flax import nnx +from absl import app +from maxdiffusion import pyconfig, max_logging +from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel + +def run(config): + max_logging.log("Wan 2.1 inference script") + + wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index b22ed60e9..fc838c9db 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -14,7 +14,7 @@ import functools import math - +from typing import Optional import flax.linen as nn import jax import jax.numpy as jnp @@ -414,7 +414,7 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) -class FlaxWanAttention(nn.module): +class FlaxWanAttention(nn.Module): query_dim: int heads: int = 8 dim_head: int = 64 diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py index 968798c5d..cdcecf80b 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py @@ -19,7 +19,6 @@ import math import jax.numpy as jnp from chex import Array -import flax import flax.linen as nn from ...attention_flax import FlaxFeedForward, Fla diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py new file mode 100644 index 000000000..e5c1b7145 --- /dev/null +++ b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py @@ -0,0 +1,70 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import jax +import jax.numpy as jnp +from flax import nnx +from .... import common_types, max_logging +from ...modeling_flax_utils import FlaxModelMixin +from ....configuration_utils import ConfigMixin + +BlockSizes = common_types.BlockSizes + +class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): + def __init__( + self, + rngs: nnx.Rngs, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2038, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", + ): + self.path_embedding = nnx.Conv( + in_dim, + dim, + kernel_size=patch_size, + strides=patch_size, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ("batch",) + ), + rngs=rngs + ) + + def __call__(self, x): + x = self.path_embedding(x) + return x From 9276f26180fbc8e579b9dbf38f02d78a7c8273c8 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 18 Apr 2025 23:06:25 +0000 Subject: [PATCH 03/54] wan pipeline wip --- README.md | 8 ++ src/maxdiffusion/configs/base_wan_t2v.yml | 4 +- src/maxdiffusion/generate_wan.py | 11 +- src/maxdiffusion/image_processor.py | 46 +++++++ .../models/wan/autoencoder_kl_wan.py | 87 ++++++++++++++ src/maxdiffusion/pipelines/wan/__init__.py | 0 .../pipelines/wan/pipeline_wan.py | 84 +++++++++++++ src/maxdiffusion/schedulers/__init__.py | 2 + ...heduling_flow_match_euler_discrete_flax.py | 71 +++++++++++ src/maxdiffusion/video_processor.py | 113 ++++++++++++++++++ 10 files changed, 422 insertions(+), 4 deletions(-) create mode 100644 src/maxdiffusion/models/wan/autoencoder_kl_wan.py create mode 100644 src/maxdiffusion/pipelines/wan/__init__.py create mode 100644 src/maxdiffusion/pipelines/wan/pipeline_wan.py create mode 100644 src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py create mode 100644 src/maxdiffusion/video_processor.py diff --git a/README.md b/README.md index 081d65e8e..2dc6523c4 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ MaxDiffusion supports - [Training](#training) - [Dreambooth](#dreambooth) - [Inference](#inference) + - [Wan 2.1](#wan) - [Flux](#flux) - [Fused Attention for GPU:](#fused-attention-for-gpu) - [Hyper SDXL LoRA](#hyper-sdxl-lora) @@ -171,6 +172,13 @@ To generate images, run the following command: ```bash python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run" ``` + + ## Wan + + ```bash + python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_t2v.yml run_name="wan-test" output_dir="gs://jfacevedo-maxdiffusion" jax_cache_dir="/tmp/" + ``` + ## Flux First make sure you have permissions to access the Flux repos in Huggingface. diff --git a/src/maxdiffusion/configs/base_wan_t2v.yml b/src/maxdiffusion/configs/base_wan_t2v.yml index 944153d64..28ef6e77e 100644 --- a/src/maxdiffusion/configs/base_wan_t2v.yml +++ b/src/maxdiffusion/configs/base_wan_t2v.yml @@ -23,9 +23,7 @@ gcs_metrics: False save_config_to_gcs: False log_period: 100 -pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' -clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' -t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' +pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' # Flux params flux_name: "flux-dev" diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 599fbc189..27839954b 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -19,11 +19,20 @@ from absl import app from maxdiffusion import pyconfig, max_logging from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel +from maxdiffusion.pipelines.wan.pipeline_wan import WanPipeline def run(config): max_logging.log("Wan 2.1 inference script") - wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) + pipeline, params = WanPipeline.from_pretrained( + config.pretrained_model_name_or_path, + vae=None, + transformer=None + ) + breakpoint() + + #wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) diff --git a/src/maxdiffusion/image_processor.py b/src/maxdiffusion/image_processor.py index 788e1c94e..2f99d0f7f 100644 --- a/src/maxdiffusion/image_processor.py +++ b/src/maxdiffusion/image_processor.py @@ -35,6 +35,52 @@ List[torch.FloatTensor], ] +def is_valid_image(image) -> bool: + r""" + Checks if the input is a valid image. + + A valid image can be: + - A `PIL.Image.Image`. + - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. + + Returns: + `bool`: + `True` if the input is a valid image, `False` otherwise. + """ + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) + + +def is_valid_image_imagelist(images): + r""" + Checks if the input is a valid image or list of images. + + The input can be one of the following formats: + - A 4D tensor or numpy array (batch of images). + - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or + `torch.Tensor`. + - A list of valid images. + + Args: + images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): + The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid + images. + + Returns: + `bool`: + `True` if the input is valid, `False` otherwise. + """ + if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: + return True + elif is_valid_image(images): + return True + elif isinstance(images, list): + return all(is_valid_image(image) for image in images) + return False + class VaeImageProcessor(ConfigMixin): """ diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py new file mode 100644 index 000000000..4dc878ce4 --- /dev/null +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -0,0 +1,87 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Tuple, List +from flax import nnx +from ...configuration_utils import ConfigMixin, flax_register_to_config +from ..modeling_flax_utils import FlaxModelMixin + +class WanEncoder3d(nnx.Module): + pass + +class WanCausalConv3d(nnx.Module): + pass + +class WanDecoder3d(nnx.Module): + pass + +class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1,2,4,4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temporal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + latents_std: List[float] = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ], + ): + self.z_dim = z_dim + self.temporal_downsample = temporal_downsample + self.temporal_upsample = temporal_downsample[::-1] + + self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) + + self.decoder = WanDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout + ) \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/wan/__init__.py b/src/maxdiffusion/pipelines/wan/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/pipelines/wan/pipeline_wan.py b/src/maxdiffusion/pipelines/wan/pipeline_wan.py new file mode 100644 index 000000000..df4cbb748 --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/pipeline_wan.py @@ -0,0 +1,84 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Union, List +from transformers import AutoTokenizer, UMT5EncoderModel +import torch +from ...models.wan.transformers.transformer_flux_wan_nnx import WanModel +from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from ...video_processor import VideoProcessor +from ...schedulers import FlowMatchEulerDiscreteScheduler + +class WanPipeline(FlaxDiffusionPipeline): + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: WanModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + + def _get_t5_prompt_embds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state + # prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds \ No newline at end of file diff --git a/src/maxdiffusion/schedulers/__init__.py b/src/maxdiffusion/schedulers/__init__.py index b630948e0..edd249de1 100644 --- a/src/maxdiffusion/schedulers/__init__.py +++ b/src/maxdiffusion/schedulers/__init__.py @@ -48,6 +48,7 @@ _import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"] _import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"] _import_structure["scheduling_sde_ve_flax"] = ["FlaxScoreSdeVeScheduler"] + _import_structure["scheduling_flow_match_euler_discrete_flax"] = ["FlowMatchEulerDiscreteScheduler"] _import_structure["scheduling_utils_flax"] = [ "FlaxKarrasDiffusionSchedulers", "FlaxSchedulerMixin", @@ -73,6 +74,7 @@ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler + from .scheduling_flow_match_euler_discrete_flax import FlowMatchEulerDiscreteScheduler from .scheduling_utils_flax import ( FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py new file mode 100644 index 000000000..db96f53fd --- /dev/null +++ b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py @@ -0,0 +1,71 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + # FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) + +@flax.struct.dataclass +class FlowMatchEulerDiscreteSchedulerState: + common: CommonSchedulerState + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(FlaxSchedulerOutput): + state: FlowMatchEulerDiscreteSchedulerState + +class FlowMatchEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): + # _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + time_shift_type: str = "exponential", + dtype: jnp.dtype = jnp.float32 + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> FlowMatchEulerDiscreteSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) \ No newline at end of file diff --git a/src/maxdiffusion/video_processor.py b/src/maxdiffusion/video_processor.py new file mode 100644 index 000000000..2da782b46 --- /dev/null +++ b/src/maxdiffusion/video_processor.py @@ -0,0 +1,113 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch + +from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist + + +class VideoProcessor(VaeImageProcessor): + r"""Simple video processor.""" + + def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor: + r""" + Preprocesses input video(s). + + Args: + video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`): + The input video. It can be one of the following: + * List of the PIL images. + * List of list of PIL images. + * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`). + * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, + width)`). + * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width, + num_channels)`. + * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height, + width)`. + height (`int`, *optional*, defaults to `None`): + The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to + get default height. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get + the default width. + """ + if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", + FutureWarning, + ) + video = np.concatenate(video, axis=0) + if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor", + FutureWarning, + ) + video = torch.cat(video, axis=0) + + # ensure the input is a list of videos: + # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) + # - if it is a single video, it is convereted to a list of one video. + if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: + video = list(video) + elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): + video = [video] + elif isinstance(video, list) and is_valid_image_imagelist(video[0]): + video = video + else: + raise ValueError( + "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" + ) + + video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + + # move the number of channels before the number of frames. + video = video.permute(0, 2, 1, 3, 4) + + return video + + def postprocess_video( + self, video: torch.Tensor, output_type: str = "np" + ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: + r""" + Converts a video tensor to a list of frames for export. + + Args: + video (`torch.Tensor`): The video as a tensor. + output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. + """ + batch_size = video.shape[0] + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = self.postprocess(batch_vid, output_type) + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + elif output_type == "pt": + outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") + + return outputs From 120ceb3b98e229cc0adc7963eca28abeaf33aea0 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 22 Apr 2025 18:26:04 +0000 Subject: [PATCH 04/54] wip - vae --- README.md | 8 + src/maxdiffusion/configs/base_wan_t2v.yml | 4 +- src/maxdiffusion/generate_wan.py | 191 ++++++++++++++++- src/maxdiffusion/image_processor.py | 46 +++++ .../models/wan/autoencoder_kl_wan.py | 193 ++++++++++++++++++ src/maxdiffusion/pipelines/wan/__init__.py | 0 .../pipelines/wan/pipeline_wan.py | 84 ++++++++ src/maxdiffusion/schedulers/__init__.py | 2 + ...heduling_flow_match_euler_discrete_flax.py | 71 +++++++ src/maxdiffusion/tests/wan_vae_test.py | 30 +++ src/maxdiffusion/video_processor.py | 113 ++++++++++ 11 files changed, 737 insertions(+), 5 deletions(-) create mode 100644 src/maxdiffusion/models/wan/autoencoder_kl_wan.py create mode 100644 src/maxdiffusion/pipelines/wan/__init__.py create mode 100644 src/maxdiffusion/pipelines/wan/pipeline_wan.py create mode 100644 src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py create mode 100644 src/maxdiffusion/tests/wan_vae_test.py create mode 100644 src/maxdiffusion/video_processor.py diff --git a/README.md b/README.md index 081d65e8e..2dc6523c4 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ MaxDiffusion supports - [Training](#training) - [Dreambooth](#dreambooth) - [Inference](#inference) + - [Wan 2.1](#wan) - [Flux](#flux) - [Fused Attention for GPU:](#fused-attention-for-gpu) - [Hyper SDXL LoRA](#hyper-sdxl-lora) @@ -171,6 +172,13 @@ To generate images, run the following command: ```bash python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run" ``` + + ## Wan + + ```bash + python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_t2v.yml run_name="wan-test" output_dir="gs://jfacevedo-maxdiffusion" jax_cache_dir="/tmp/" + ``` + ## Flux First make sure you have permissions to access the Flux repos in Huggingface. diff --git a/src/maxdiffusion/configs/base_wan_t2v.yml b/src/maxdiffusion/configs/base_wan_t2v.yml index 944153d64..28ef6e77e 100644 --- a/src/maxdiffusion/configs/base_wan_t2v.yml +++ b/src/maxdiffusion/configs/base_wan_t2v.yml @@ -23,9 +23,7 @@ gcs_metrics: False save_config_to_gcs: False log_period: 100 -pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' -clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' -t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' +pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' # Flux params flux_name: "flux-dev" diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 599fbc189..6945a62f5 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -13,18 +13,205 @@ See the License for the specific language governing permissions and limitations under the License. """ - -from typing import Callable, List, Union, Sequence +import html +from typing import Callable, List, Union, Sequence, Optional +import time +import torch +import ftfy +import regex as re +import jax +from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P from flax import nnx from absl import app +from transformers import AutoTokenizer, UMT5EncoderModel from maxdiffusion import pyconfig, max_logging from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel +from maxdiffusion.pipelines.wan.pipeline_wan import WanPipeline + +from maxdiffusion.max_utils import ( + device_put_replicated, + get_memory_allocations, + create_device_mesh, + get_flash_block_sizes, + get_precision, + setup_initial_state, +) + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + +def _get_t5_prompt_embeds( + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + +def encode_prompt( + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = _get_t5_prompt_embeds( + tokenizer=tokenizer, + text_encoder=text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = _get_t5_prompt_embeds( + tokenizer=tokenizer, + text_encoder=text_encoder, + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds def run(config): max_logging.log("Wan 2.1 inference script") + rng = jax.random.key(config.seed) + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + global_batch_size = config.per_device_batch_size * jax.local_device_count() + + tokenizer = AutoTokenizer.from_pretrained( + config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype + ) + text_encoder = UMT5EncoderModel.from_pretrained( + config.pretrained_model_name_or_path, subfolder="text_encoder", + ) + s0 = time.perf_counter() + prompt_embeds, negative_prompt_embeds = encode_prompt( + tokenizer=tokenizer, + text_encoder=text_encoder, + prompt=config.prompt, + negative_prompt=config.negative_prompt + ) + max_logging.log(f"text encoding time: {(time.perf_counter() - s0)}") + + # pipeline, params = WanPipeline.from_pretrained( + # config.pretrained_model_name_or_path, + # #vae=None, + # #transformer=None + # ) + # breakpoint() + wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) run(pyconfig.config) diff --git a/src/maxdiffusion/image_processor.py b/src/maxdiffusion/image_processor.py index 788e1c94e..2f99d0f7f 100644 --- a/src/maxdiffusion/image_processor.py +++ b/src/maxdiffusion/image_processor.py @@ -35,6 +35,52 @@ List[torch.FloatTensor], ] +def is_valid_image(image) -> bool: + r""" + Checks if the input is a valid image. + + A valid image can be: + - A `PIL.Image.Image`. + - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. + + Returns: + `bool`: + `True` if the input is a valid image, `False` otherwise. + """ + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) + + +def is_valid_image_imagelist(images): + r""" + Checks if the input is a valid image or list of images. + + The input can be one of the following formats: + - A 4D tensor or numpy array (batch of images). + - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or + `torch.Tensor`. + - A list of valid images. + + Args: + images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): + The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid + images. + + Returns: + `bool`: + `True` if the input is valid, `False` otherwise. + """ + if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: + return True + elif is_valid_image(images): + return True + elif isinstance(images, list): + return all(is_valid_image(image) for image in images) + return False + class VaeImageProcessor(ConfigMixin): """ diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py new file mode 100644 index 000000000..65da97bc9 --- /dev/null +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -0,0 +1,193 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Tuple, List, Sequence, Union, Optional + +import jax +import jax.numpy as jnp +from flax import nnx +from ...configuration_utils import ConfigMixin, flax_register_to_config +from ..modeling_flax_utils import FlaxModelMixin +from ... import common_types + +BlockSizes = common_types.BlockSizes + +_ACTIVATIONS = { + "swish": jax.nn.silu, + "silu": jax.nn.silu, + "relu": jax.nn.relu, + "gelu": jax.nn.gelu, + "mish": jax.nn.mish +} + +def get_activation(name: str): + func = _ACTIVATIONS.get(name) + if func is None: + raise ValueError(f"Unknown activation function: {name}") + return func + +# Helper to ensure kernel_size, stride, padding are tuples of 3 integers +def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: + """Canonicalizes a value to a tuple of integers.""" + if isinstance(x, int): + return (x,) * rank + elif isinstance(x, Sequence) and len(x) == rank: + return tuple(x) + else: + raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank} integers. Got {x}") + +class WanCausalConv3d(nnx.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + *, # Mark subsequent arguments as keyword-only + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + use_bias: bool = True, + rngs: nnx.Rngs, # rngs are required for initializing parameters, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", + ): + self.kernel_size = _canonicalize_tuple(kernel_size, 3, 'kernel_size') + self.stride = _canonicalize_tuple(stride, 3, 'stride') + padding_tuple = _canonicalize_tuple(padding, 3, 'padding') # (D, H, W) padding amounts + + self._causal_padding = ( + (0, 0), # Batch dimension - no padding + (2 * padding_tuple[0], 0), # Depth dimension - causal padding (pad only before) + (padding_tuple[1], padding_tuple[1]), # Height dimension - symmetric padding + (padding_tuple[2], padding_tuple[2]), # Width dimension - symmetric padding + (0, 0) # Channel dimension - no padding + ) + + # Store the amount of padding needed *before* the depth dimension for caching logoic + self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] + + self.conv = nnx.Conv( + in_features=in_channels, + out_features=out_channels, + kernel_size=self.kernel_size, + strides=self.stride, + use_bias=use_bias, + padding='VALID', # Handle padding manually + rngs=rngs + ) + + def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Array: + current_padding = list(self._causal_padding) # Mutable copy + padding_needed = self._depth_padding_before + + if cache_x is not None and padding_needed > 0: + # Ensure cache has same spatial/channel dims, potentially different depth + assert cache_x.shape[0] == x.shape[0] and \ + cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" + + cache_len = cache_x.shape[1] + x = jnp.concatenate([cache_x, x], axis=1) # Concat along depth (D) + + padding_needed -= cache_len + if padding_needed < 0: + # Cache longer than needed padding, trim from start + x = x[:, -padding_needed:, ...] + current_padding[1] = (0, 0) # No explicit padding needed now + else: + # Update depth padding needed + current_padding[1] = (padding_needed, 0) + + # Apply padding if any dimension requires it + padding_to_apply = tuple(current_padding) + if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads): + x_padded = jnp.pad(x, padding_to_apply, mode='constant', constant_values=0.0) + else: + x_padded = x + + out = self.conv(x_padded) + return out + +class WanRMS_norm(nnx.Module): + def __init__( + self, + dim: int, + *, + eps: float = 1e-6, + use_bias: bool = False, + rngs: nnx.Rngs + ): + self.eps = eps + shape = (dim,) + self.scale = dim ** 0.5 + # Initialize gamma as parameter + self.gamma = nnx.Param(jax.random.ones(rngs.params(), shape)) + if use_bias: + self.bias = nnx.Param(jax.random.zeros(rngs.params(), shape)) + else: + self.bias = None + + def __call__(self, x: jax.Array) -> jax.Array: + # Expects input channels in the last dimension + variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True) + inv_std = jax.lax.rsqrt(variance + self.eps) + normalized = x * inv_std * self.gamma.value * self.scale + if self.bias: + return normalized + self.bias.value + return normalized + + +class WanEncoder3d(nnx.Module): + pass + +class WanCausalConv3d(nnx.Module): + pass + +class WanDecoder3d(nnx.Module): + pass + +class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1,2,4,4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temporal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [ + -0.7571,-0.7089,-0.9113,0.1075,-0.1745,0.9653,-0.1517, 1.5508, + 0.4134,-0.0715,0.5517,-0.3632,-0.1922,-0.9497,0.2503,-0.2921, + ], + latents_std: List[float] = [ + 2.8184,1.4541,2.3275,2.6558,1.2196,1.7708,2.6052,2.0743, + 3.2687,2.1526,2.8652,1.5579,1.6382,1.1253,2.8251,1.9160, + ], + ): + self.z_dim = z_dim + self.temporal_downsample = temporal_downsample + self.temporal_upsample = temporal_downsample[::-1] + + self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) + + self.decoder = WanDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout + ) \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/wan/__init__.py b/src/maxdiffusion/pipelines/wan/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/pipelines/wan/pipeline_wan.py b/src/maxdiffusion/pipelines/wan/pipeline_wan.py new file mode 100644 index 000000000..6bd9fb09b --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/pipeline_wan.py @@ -0,0 +1,84 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Union, List +from transformers import AutoTokenizer, UMT5EncoderModel +import torch +from ...models.wan.transformers.transformer_flux_wan_nnx import WanModel +from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from ...video_processor import VideoProcessor +#from ...schedulers import FlowMatchEulerDiscreteScheduler + +class WanPipeline(FlaxDiffusionPipeline): + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: WanModel, + vae: AutoencoderKLWan, + #scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + #scheduler=scheduler + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + + def _get_t5_prompt_embds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state + # prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds \ No newline at end of file diff --git a/src/maxdiffusion/schedulers/__init__.py b/src/maxdiffusion/schedulers/__init__.py index b630948e0..edd249de1 100644 --- a/src/maxdiffusion/schedulers/__init__.py +++ b/src/maxdiffusion/schedulers/__init__.py @@ -48,6 +48,7 @@ _import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"] _import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"] _import_structure["scheduling_sde_ve_flax"] = ["FlaxScoreSdeVeScheduler"] + _import_structure["scheduling_flow_match_euler_discrete_flax"] = ["FlowMatchEulerDiscreteScheduler"] _import_structure["scheduling_utils_flax"] = [ "FlaxKarrasDiffusionSchedulers", "FlaxSchedulerMixin", @@ -73,6 +74,7 @@ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler + from .scheduling_flow_match_euler_discrete_flax import FlowMatchEulerDiscreteScheduler from .scheduling_utils_flax import ( FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py new file mode 100644 index 000000000..db96f53fd --- /dev/null +++ b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py @@ -0,0 +1,71 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + # FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) + +@flax.struct.dataclass +class FlowMatchEulerDiscreteSchedulerState: + common: CommonSchedulerState + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(FlaxSchedulerOutput): + state: FlowMatchEulerDiscreteSchedulerState + +class FlowMatchEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): + # _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + time_shift_type: str = "exponential", + dtype: jnp.dtype = jnp.float32 + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> FlowMatchEulerDiscreteSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py new file mode 100644 index 000000000..d82a129e0 --- /dev/null +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -0,0 +1,30 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import os +import unittest +import pytest +from absl.testing import absltest + +class WanVaeTest(unittest.TestCase): + def setUp(self): + WanVaeTest.dummy_data = {} + + # def test_3d_conv(self): + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/src/maxdiffusion/video_processor.py b/src/maxdiffusion/video_processor.py new file mode 100644 index 000000000..2da782b46 --- /dev/null +++ b/src/maxdiffusion/video_processor.py @@ -0,0 +1,113 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch + +from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist + + +class VideoProcessor(VaeImageProcessor): + r"""Simple video processor.""" + + def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor: + r""" + Preprocesses input video(s). + + Args: + video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`): + The input video. It can be one of the following: + * List of the PIL images. + * List of list of PIL images. + * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`). + * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, + width)`). + * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width, + num_channels)`. + * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height, + width)`. + height (`int`, *optional*, defaults to `None`): + The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to + get default height. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get + the default width. + """ + if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", + FutureWarning, + ) + video = np.concatenate(video, axis=0) + if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor", + FutureWarning, + ) + video = torch.cat(video, axis=0) + + # ensure the input is a list of videos: + # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) + # - if it is a single video, it is convereted to a list of one video. + if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: + video = list(video) + elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): + video = [video] + elif isinstance(video, list) and is_valid_image_imagelist(video[0]): + video = video + else: + raise ValueError( + "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" + ) + + video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + + # move the number of channels before the number of frames. + video = video.permute(0, 2, 1, 3, 4) + + return video + + def postprocess_video( + self, video: torch.Tensor, output_type: str = "np" + ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: + r""" + Converts a video tensor to a list of frames for export. + + Args: + video (`torch.Tensor`): The video as a tensor. + output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. + """ + batch_size = video.shape[0] + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = self.postprocess(batch_vid, output_type) + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + elif output_type == "pt": + outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") + + return outputs From cc11bb155f92e54cbb52bb7b39814c2fd7614953 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 23 Apr 2025 20:11:40 +0000 Subject: [PATCH 05/54] added tests for a couple of wan vae layers. --- .../models/wan/autoencoder_kl_wan.py | 369 +++++++++++++++++- src/maxdiffusion/tests/wan_vae_test.py | 134 ++++++- 2 files changed, 485 insertions(+), 18 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 65da97bc9..9d0f89621 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -22,9 +22,12 @@ from ...configuration_utils import ConfigMixin, flax_register_to_config from ..modeling_flax_utils import FlaxModelMixin from ... import common_types +from ..vae_flax import FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution BlockSizes = common_types.BlockSizes +CACHE_T = 2 + _ACTIVATIONS = { "swish": jax.nn.silu, "silu": jax.nn.silu, @@ -128,43 +131,350 @@ class WanRMS_norm(nnx.Module): def __init__( self, dim: int, - *, + rngs: nnx.Rngs, + channel_first: bool = True, + images: bool = True, eps: float = 1e-6, use_bias: bool = False, - rngs: nnx.Rngs ): + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim, ) self.eps = eps - shape = (dim,) + self.channel_first = channel_first self.scale = dim ** 0.5 # Initialize gamma as parameter - self.gamma = nnx.Param(jax.random.ones(rngs.params(), shape)) + self.gamma = nnx.Param(jnp.ones(shape)) if use_bias: - self.bias = nnx.Param(jax.random.zeros(rngs.params(), shape)) + self.bias = nnx.Param(jnp.zeros(shape)) else: - self.bias = None + self.bias = 0 def __call__(self, x: jax.Array) -> jax.Array: - # Expects input channels in the last dimension - variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True) - inv_std = jax.lax.rsqrt(variance + self.eps) - normalized = x * inv_std * self.gamma.value * self.scale + normalized = jnp.linalg.norm(x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True) + normalized = x / jnp.maximum(normalized, self.eps) + normalized = normalized * self.scale * self.gamma if self.bias: return normalized + self.bias.value return normalized +class WanUpsample(nnx.Module): + def __init__(self, scale_factor: Tuple[float, float], method: str = 'nearest'): + # scale_factor for (H, W) + # JAX resize works on spatial dims, H, W assumming (N, D, H, W, C) or (N, H, W, C) + self.scale_factor = scale_factor + self.method = method + + def __call__(self, x: jax.Array) -> jax.Array: + input_dtype = x.dtype + in_shape = x.shape + is_3d = len(in_shape) == 5 + n, d, h, w, c = in_shape if is_3d else(in_shape[0], 1, in_shape[1], in_shape[2], in_shape[3]) + + target_h = int(h * self.scale_factor[0]) + target_w = int(w * self.scale_factor[1]) + + # jax.image.resize expects (..., H, W, C) + if is_3d: + x_reshaped = x.reshape(n * d, h, w, c) + out_reshaped = jax.image.resize(x_reshaped.astype(jnp.float32), + (n * d, target_h, target_w, c), + method=self.method) + out = out_reshaped.reshape(n, d, target_h, target_w, c) + else: # Asumming (N, H, W, C) + out = jax.image.resize(x.astype(jnp.float32), + (n, target_h, target_w, c), + method=self.method) + + return out.astype(input_dtype) + +class Identity(nnx.Module): + def __call__(self, x): + return x + +class ZeroPaddedConv2D(nnx.Module): + def __init__( + self, + dim: int, + rngs: nnx.Rngs, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", + ): + self.conv = nnx.Conv( + dim, + dim, + kernel_size=(1, 3, 3), + padding='SAME', + use_bias=True, + rngs=rngs + ) + + def __call__(self, x): + # This pad assumes (B, C, H, W) + x = jax.lax.pad(x, 0.0, [(0, 0, 0), (0, 0, 0), (0, 1, 0), (0, 1, 0)]) + return self.conv(x) + + +class WanResample(nnx.Module): + def __init__( + self, + dim: int, + mode: str, + rngs: nnx.Rngs, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", + ): + self.dim = dim + self.mode = mode + self.time_conv = None + + if mode == "upsample2d": + self.resample = nnx.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), + nnx.Conv( + dim, + dim // 2, + kernel_size=(1, 3, 3), + padding='SAME', + use_bias=True, + rngs=rngs, + ) + ) + elif mode == "upsample3d": + self.resample = nnx.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), + nnx.Conv( + dim, + dim // 2, + kernel_size=(1, 3, 3), + padding='SAME', + use_bias=True, + rngs=rngs, + ) + ) + self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0), rngs=rngs) + elif mode == "downsample2d": + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs) + elif mode == "downsample3d": + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs) + else: + self.resample = Identity() + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: + return x + +class WanResidualBlock(nnx.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + rngs: nnx.Rngs, + dropout: float = 0.0, + non_linearity: str = "silu", + ): + pass + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + return x + +class WanAttentionBlock(nnx.Module): + def __init__( + self, + dim: int, + rngs: nnx.Rngs + ): + self.dim = dim + + def __call__(self, x: jax.Array): + return x + +class WanMidBlock(nnx.Module): + def __init__( + self, + dim: int, + rngs: nnx.Rngs, + dropout: float = 0.0, + non_linearity: str = "silu", + num_layers: int = 1 + ): + self.dim = dim + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + return x + +class WanUpBlock(nnx.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + rngs: nnx.Rngs, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu" + ): + # Create layers list + self.resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + self.resnets.append(WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs)) + current_dim = out_dim + + # Add upsampling layer if needed. + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs) + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + return x class WanEncoder3d(nnx.Module): - pass + def __init__( + self, + rngs: nnx.Rngs, + dim: int =128, + z_dim: int = 4, + dim_mult = [1, 2, 4, 4], + num_res_blocks = 2, + attn_scales = [], + temperal_downsample = [True, True, False], + dropout = 0.0, + non_linearity: str = 'silu', + ): + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) -class WanCausalConv3d(nnx.Module): - pass + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1, rngs=rngs) + + # downsample blocks + self.down_blocks = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(WanResidualBlock(in_dim=in_dim, out_dim=out_dim, dropout=dropout, rngs=rngs)) + if scale in attn_scales: + self.down_blocks.append(WanAttentionBlock(dim=out_dim, rngs=rngs)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs)) + + # middle_blocks + self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1, rngs=rngs) + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False, rngs=rngs) + self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=z_dim, kernel_size=3, padding=1) + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + return x class WanDecoder3d(nnx.Module): - pass + r""" + A 3D decoder module. + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + def __init__( + self, + rngs: nnx.Rngs, + dim: int = 128, + z_dim: int = 128, + dim_mult: List[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales = List[float], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + ): + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = WanCausalConv3d(in_channels=z_dim, out_channels=dims[0], kernel_size=3, padding=1, rngs=rngs) + + # middle_blocks + self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Crete and add the upsampling block + up_block = WanUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + rngs=rngs + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *=2.0 + + # output blocks + self.norm_out = nnx.RMSNorm(num_features=out_dim, ) + self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs) + self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=3, kernel_size=3, padding=1, rngs=rngs) + + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + return x class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): def __init__( self, + rngs: nnx.Rngs, base_dim: int = 96, z_dim: int = 16, dim_mult: Tuple[int] = [1,2,4,4], @@ -186,8 +496,35 @@ def __init__( self.temporal_upsample = temporal_downsample[::-1] self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1) - self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) + self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1, rngs=rngs) self.decoder = WanDecoder3d( base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout - ) \ No newline at end of file + ) + self.clear_cache() + + def clear_cache(self): + """ Resets cache dictionaries and indices""" + def _count_conv3d(module): + count = 0 + node_types = nnx.graph.iter_graph(module, nnx.Module) + for node in node_types: + if isinstance(node.value, WanCausalConv3d): + count +=1 + return count + + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + + def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: + """ Encode video into latent distribution.""" + if x.shape[-1] != 3: + raise ValueError(f"Expected input shape (N, D, H, W, 3), got {x.shape}") + + self.clear_cache() \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index d82a129e0..a65f0c425 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -15,16 +15,146 @@ """ import os +import jax +import jax.numpy as jnp +from flax import nnx +import numpy as np import unittest import pytest from absl.testing import absltest +from ..models.wan.autoencoder_kl_wan import WanCausalConv3d, WanUpsample, AutoencoderKLWan, WanRMS_norm class WanVaeTest(unittest.TestCase): def setUp(self): WanVaeTest.dummy_data = {} - # def test_3d_conv(self): + # def test_clear_cache(self): + # key = jax.random.key(0) + # rngs = nnx.Rngs(key) + # wan_vae = AutoencoderKLWan(rngs=rngs) + # wan_vae.clear_cache() + + def test_wanrms_norm(self): + """Test against the Pytorch implementation""" + import torch + import torch.nn as nn + import torch.nn.functional as F + + class TorchWanRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + # --- Test Case 1: images == True --- + model = TorchWanRMS_norm(2) + input_shape = (1, 2, 2, 2, 3) + input = torch.ones(input_shape) + torch_output = model(input) + torch_output_np = torch_output.detach().numpy() + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + wanrms_norm = WanRMS_norm(dim=2, rngs=rngs) + input_shape = (1, 2, 2, 2, 3) + dummy_input = jnp.ones(input_shape) + output = wanrms_norm(dummy_input) + output_np = np.array(output) + assert np.allclose(output_np, torch_output_np) == True + + # --- Test Case 2: images == False --- + model = TorchWanRMS_norm(2, images=False) + input_shape = (1, 2, 2, 2, 3) + input = torch.ones(input_shape) + torch_output = model(input) + torch_output_np = torch_output.detach().numpy() + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + wanrms_norm = WanRMS_norm(dim=2, rngs=rngs, images=False) + input_shape = (1, 2, 2, 2, 3) + dummy_input = jnp.ones(input_shape) + output = wanrms_norm(dummy_input) + output_np = np.array(output) + assert np.allclose(output_np, torch_output_np) == True + + def test_wan_upsample(self): + batch_size=1 + in_depth, in_height, in_width = 10, 32, 32 + in_channels = 3 + + dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels)) + + upsample = WanUpsample(scale_factor=(2.0, 2.0)) + + # --- Test Case 1: depth > 1 --- + output = upsample(dummy_input) + assert output.shape == (1, 10, 64, 64, 3) + + in_depth = 1 + dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels)) + # --- Test Case 1: depth == 1 --- + output = upsample(dummy_input) + assert output.shape == (1, 1, 64, 64, 3) + + def test_3d_conv(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + batch_size=1 + in_depth, in_height, in_width = 10, 32, 32 + in_channels = 3 + out_channels = 16 + kernel_d, kernel_h, kernel_w = 3, 3, 3 # Kernel size (Depth, Height, Width) + padding_d, padding_h, padding_w = 1, 1, 1 # Base padding (Depth, Height, Width) + + # Create dummy input data + dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels)) + + # Create dummy cache data (from a previous step) + cache_depth = 2 * padding_d + dummy_cache = jnp.zeros((batch_size, cache_depth, in_height, in_width, in_channels)) + + # Instantiate the module + causal_conv_layer = WanCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(kernel_d, kernel_h, kernel_w), + padding=(padding_d, padding_h, padding_w), + rngs=rngs # Pass rngs for initialization + ) + + # --- Test Case 1: No Cache --- + output_no_cache = causal_conv_layer(dummy_input) + assert output_no_cache.shape == (1, 10, 32, 32, 16) + + # --- Test Case 2: With Cache --- + output_with_cache = causal_conv_layer(dummy_input, cache_x=dummy_cache) + assert output_with_cache.shape == (1, 10, 32, 32, 16) + + # --- Test Case 3: With Cache larger than padding --- + larger_cache_depth = 4 # Larger than needed padding (2*padding_d = 2) + dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels)) + output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache) + assert output_with_larger_cache.shape == (1, 10, 32, 32, 16) - if __name__ == "__main__": absltest.main() \ No newline at end of file From 4e443b82e58a0a4adca9678209ac91927b6bcbb1 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 23 Apr 2025 22:18:31 +0000 Subject: [PATCH 06/54] add unit tests to wan vae padded conv. --- src/maxdiffusion/generate_wan.py | 1 - .../models/wan/autoencoder_kl_wan.py | 30 ++++++++--- src/maxdiffusion/tests/wan_vae_test.py | 51 +++++++++++++++---- 3 files changed, 66 insertions(+), 16 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index c35aafecc..e19d783de 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -214,7 +214,6 @@ def run(config): vae=None, transformer=None ) - breakpoint() #wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 9d0f89621..61473ca5b 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -192,10 +192,16 @@ def __call__(self, x): return x class ZeroPaddedConv2D(nnx.Module): + """ + Module for adding padding before conv. + Currently it does not add any padding. + """ def __init__( self, dim: int, rngs: nnx.Rngs, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, flash_min_seq_length: int = 4096, flash_block_sizes: BlockSizes = None, mesh: jax.sharding.Mesh = None, @@ -204,18 +210,18 @@ def __init__( precision: jax.lax.Precision = None, attention: str = "dot_product", ): + kernel_size = _canonicalize_tuple(kernel_size, 3, 'kernel_size') + stride = _canonicalize_tuple(stride, 3, 'stride') self.conv = nnx.Conv( dim, dim, - kernel_size=(1, 3, 3), - padding='SAME', + kernel_size=kernel_size, + strides=stride, use_bias=True, rngs=rngs ) def __call__(self, x): - # This pad assumes (B, C, H, W) - x = jax.lax.pad(x, 0.0, [(0, 0, 0), (0, 0, 0), (0, 1, 0), (0, 1, 0)]) return self.conv(x) @@ -263,9 +269,21 @@ def __init__( ) self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0), rngs=rngs) elif mode == "downsample2d": - self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs) + # TODO - do I need to transpose? + self.resample = ZeroPaddedConv2D( + dim=dim, + rngs=rngs, + kernel_size=(1, 3, 3), + stride=(1, 2, 2) + ) elif mode == "downsample3d": - self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs) + # TODO - do I need to transpose? + self.resample = ZeroPaddedConv2D( + dim=dim, + rngs=rngs, + kernel_size=(1, 3, 3), + stride=(1, 2, 2) + ) else: self.resample = Identity() diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index a65f0c425..3aa64a133 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -22,7 +22,13 @@ import unittest import pytest from absl.testing import absltest -from ..models.wan.autoencoder_kl_wan import WanCausalConv3d, WanUpsample, AutoencoderKLWan, WanRMS_norm +from ..models.wan.autoencoder_kl_wan import ( + WanCausalConv3d, + WanUpsample, + AutoencoderKLWan, + WanRMS_norm, + ZeroPaddedConv2D +) class WanVaeTest(unittest.TestCase): def setUp(self): @@ -66,36 +72,63 @@ def forward(self, x): return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias # --- Test Case 1: images == True --- - model = TorchWanRMS_norm(2) - input_shape = (1, 2, 2, 2, 3) + dim = 96 + input_shape = (1, 96, 1, 480, 720) + + model = TorchWanRMS_norm(dim) input = torch.ones(input_shape) torch_output = model(input) torch_output_np = torch_output.detach().numpy() key = jax.random.key(0) rngs = nnx.Rngs(key) - wanrms_norm = WanRMS_norm(dim=2, rngs=rngs) - input_shape = (1, 2, 2, 2, 3) + wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs) dummy_input = jnp.ones(input_shape) output = wanrms_norm(dummy_input) output_np = np.array(output) assert np.allclose(output_np, torch_output_np) == True # --- Test Case 2: images == False --- - model = TorchWanRMS_norm(2, images=False) - input_shape = (1, 2, 2, 2, 3) + model = TorchWanRMS_norm(dim, images=False) input = torch.ones(input_shape) torch_output = model(input) torch_output_np = torch_output.detach().numpy() key = jax.random.key(0) rngs = nnx.Rngs(key) - wanrms_norm = WanRMS_norm(dim=2, rngs=rngs, images=False) - input_shape = (1, 2, 2, 2, 3) + wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs, images=False) dummy_input = jnp.ones(input_shape) output = wanrms_norm(dummy_input) output_np = np.array(output) assert np.allclose(output_np, torch_output_np) == True + + def test_zero_padded_conv(self): + import torch + import torch.nn as nn + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + + dim = 96 + kernel_size = 3 + stride= (2, 2) + resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, kernel_size, stride=stride)) + input_shape = (1, 96, 480, 720) + input = torch.ones(input_shape) + output_torch = resample(input) + assert output_torch.shape == (1, 96, 240, 360) + + model = ZeroPaddedConv2D( + dim=dim, + rngs=rngs, + kernel_size=(1, 3, 3), + stride=(1, 2, 2) + ) + dummy_input = jnp.ones(input_shape) + dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1)) + output = model(dummy_input) + output = jnp.transpose(output, (0, 3, 1, 2)) + assert output.shape == (1, 96, 240, 360) def test_wan_upsample(self): batch_size=1 From aeabe27e586cc1288e92229c5cdb34635301ff1e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 24 Apr 2025 17:55:19 +0000 Subject: [PATCH 07/54] wip - test for vae encoder. --- .../models/wan/autoencoder_kl_wan.py | 88 +++++++- src/maxdiffusion/tests/wan_vae_test.py | 196 ++++++++++++++++-- 2 files changed, 256 insertions(+), 28 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 61473ca5b..faac464cc 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -55,14 +55,13 @@ def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> T class WanCausalConv3d(nnx.Module): def __init__( self, + rngs: nnx.Rngs, # rngs are required for initializing parameters, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int, int]], - *, # Mark subsequent arguments as keyword-only stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, use_bias: bool = True, - rngs: nnx.Rngs, # rngs are required for initializing parameters, flash_min_seq_length: int = 4096, flash_block_sizes: BlockSizes = None, mesh: jax.sharding.Mesh = None, @@ -267,7 +266,13 @@ def __init__( rngs=rngs, ) ) - self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0), rngs=rngs) + self.time_conv = WanCausalConv3d( + rngs=rngs, + in_channels=dim, + out_channels=dim * 2, + kernel_size=(3, 1, 1), + padding=(1, 0, 0), + ) elif mode == "downsample2d": # TODO - do I need to transpose? self.resample = ZeroPaddedConv2D( @@ -288,6 +293,15 @@ def __init__( self.resample = Identity() def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: + # Input x: (N, D, H, W, C), assume C = self.dim + n, d, h, w, c = x.shape + assert c == self.dim + + x = x.reshape(n*d,h,w,c) + x = self.resample(x) + h_new, w_new, c_new = x.shape[1:] + x = x.reshape(n, d, h_new, w_new, c_new) + return x class WanResidualBlock(nnx.Module): @@ -382,7 +396,13 @@ def __init__( scale = 1.0 # init block - self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1, rngs=rngs) + self.conv_in = WanCausalConv3d( + in_channels=3, + out_channels=dims[0], + kernel_size=3, + padding=1, + rngs=rngs + ) # downsample blocks self.down_blocks = [] @@ -400,11 +420,23 @@ def __init__( self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs)) # middle_blocks - self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1, rngs=rngs) + self.mid_block = WanMidBlock( + dim=out_dim, + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + num_layers=1, + ) # output blocks self.norm_out = WanRMS_norm(out_dim, images=False, rngs=rngs) - self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=z_dim, kernel_size=3, padding=1) + self.conv_out = WanCausalConv3d( + rngs=rngs, + in_channels=out_dim, + out_channels=z_dim, + kernel_size=3, + padding=1 + ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): return x @@ -487,6 +519,9 @@ def __init__( self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=3, kernel_size=3, padding=1, rngs=rngs) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + breakpoint() + x = self.conv_in(x) + breakpoint() return x class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -514,6 +549,7 @@ def __init__( self.temporal_upsample = temporal_downsample[::-1] self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1) + self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1, rngs=rngs) self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1, rngs=rngs) self.decoder = WanDecoder3d( @@ -539,10 +575,42 @@ def _count_conv3d(module): self._enc_conv_idx = [0] self._enc_feat_map = [None] * self._enc_conv_num + def _encode(self, x: jax.Array): + if x.shape[-1] != 3: + # reshape channel last for JAX + x = jnp.transpose(x, (0, 2, 3, 4, 1)) + assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}" + + self.clear_cache() + + t = x.shape[1] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + if i == 0: + out = self.encoder( + x[:, :1, :, :, :], + feat_cache=self._enc_feat_map, + feat_ids=self._enc_conv_idx + ) + else: + out_ = self.encoder( + x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx + ) + out = jnp.concatenate([out, out_], axis=1) + + enc = self.quant_conv(out) + mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] + enc = jnp.concatenate([mu, logvar], dim=1) + self.clear_cache() + return enc def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: """ Encode video into latent distribution.""" - if x.shape[-1] != 3: - raise ValueError(f"Expected input shape (N, D, H, W, 3), got {x.shape}") - - self.clear_cache() \ No newline at end of file + h = self._encode(x) + posterior = FlaxDiagonalGaussianDistribution(h) + if not return_dict: + return (posterior, ) + return FlaxAutoencoderKLOutput(latent_dict=posterior) + \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 3aa64a133..0e5df142f 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -15,6 +15,9 @@ """ import os +import torch +import torch.nn as nn +import torch.nn.functional as F import jax import jax.numpy as jnp from flax import nnx @@ -26,27 +29,15 @@ WanCausalConv3d, WanUpsample, AutoencoderKLWan, + WanEncoder3d, WanRMS_norm, + WanResample, ZeroPaddedConv2D ) -class WanVaeTest(unittest.TestCase): - def setUp(self): - WanVaeTest.dummy_data = {} - - # def test_clear_cache(self): - # key = jax.random.key(0) - # rngs = nnx.Rngs(key) - # wan_vae = AutoencoderKLWan(rngs=rngs) - # wan_vae.clear_cache() +CACHE_T = 2 - def test_wanrms_norm(self): - """Test against the Pytorch implementation""" - import torch - import torch.nn as nn - import torch.nn.functional as F - - class TorchWanRMS_norm(nn.Module): +class TorchWanRMS_norm(nn.Module): r""" A custom RMS normalization layer. @@ -70,6 +61,103 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi def forward(self, x): return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + +class TorchWanResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + +class WanVaeTest(unittest.TestCase): + def setUp(self): + WanVaeTest.dummy_data = {} + + # def test_clear_cache(self): + # key = jax.random.key(0) + # rngs = nnx.Rngs(key) + # wan_vae = AutoencoderKLWan(rngs=rngs) + # wan_vae.clear_cache() + + def test_wanrms_norm(self): + """Test against the Pytorch implementation""" # --- Test Case 1: images == True --- dim = 96 @@ -103,8 +191,6 @@ def forward(self, x): assert np.allclose(output_np, torch_output_np) == True def test_zero_padded_conv(self): - import torch - import torch.nn as nn key = jax.random.key(0) rngs = nnx.Rngs(key) @@ -148,6 +234,49 @@ def test_wan_upsample(self): # --- Test Case 1: depth == 1 --- output = upsample(dummy_input) assert output.shape == (1, 1, 64, 64, 3) + + def test_wan_resample(self): + # TODO - needs to test all modes - upsample2d, upsample3d, downsample2d, downsample3d and identity + key = jax.random.key(0) + rngs = nnx.Rngs(key) + + # --- Test Case 1: downsample2d --- + batch = 1 + dim = 96 + t = 1 + h = 480 + w = 720 + mode="downsample2d" + input_shape = (batch, dim, t, h, w) + expected_output_shape = (1, dim, 1, 240, 360) + # output dim should be (1, 96, 1, 480, 720) + dummy_input = torch.ones(input_shape) + torch_wan_resample = TorchWanResample( + dim=dim, + mode=mode + ) + torch_output = torch_wan_resample(dummy_input) + assert torch_output.shape == (batch, dim, t, h // 2, w //2) + + wan_resample = WanResample( + dim, + mode=mode, + rngs=rngs + ) + # channels is always last here + input_shape = (batch, t, h, w, dim) + dummy_input = jnp.ones(input_shape) + output = wan_resample(dummy_input) + assert output.shape == (batch, t, h//2, h//2, dim) + breakpoint() + + # --- Test Case 1: downsample3d --- + dim = 192 + input_shape = (1, dim, 1, 240, 360) + torch_wan_resample = WanResample( + dim=dim, + mode="downsample3d" + ) def test_3d_conv(self): key = jax.random.key(0) @@ -189,5 +318,36 @@ def test_3d_conv(self): output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache) assert output_with_larger_cache.shape == (1, 10, 32, 32, 16) + def test_wan_encode(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + dim = 96 + z_dim = 32 + dim_mult = [1, 2, 4, 4] + num_res_blocks = 2 + attn_scales = [] + temperal_downsample = [False, True, True] + nonlinearity = "silu" + wan_encoder = WanEncoder3d( + rngs=rngs, + dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + non_linearity=nonlinearity + ) + batch = 1 + channels = 3 + t = 49 + height = 480 + width = 720 + input_shape = (batch, channels, t, height, width) + input = jnp.ones(input_shape) + output = wan_encoder(input) + + + if __name__ == "__main__": absltest.main() \ No newline at end of file From 0ec4b02d00ec589db92e0f601bf955fad7c178bf Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 24 Apr 2025 21:34:40 +0000 Subject: [PATCH 08/54] Residual block test --- .../models/wan/autoencoder_kl_wan.py | 140 ++++++++++++++---- src/maxdiffusion/tests/wan_vae_test.py | 70 +++++++-- 2 files changed, 173 insertions(+), 37 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index faac464cc..8f6e2789a 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -313,10 +313,49 @@ def __init__( dropout: float = 0.0, non_linearity: str = "silu", ): - pass + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False) + self.conv1 = WanCausalConv3d( + rngs=rngs, + in_channels=in_dim, + out_channels=out_dim, + kernel_size=3, + padding=1 + ) + self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False) + self.dropout = nnx.Dropout(dropout, rngs=rngs) + self.conv2 = WanCausalConv3d( + rngs=rngs, + in_channels=out_dim, + out_channels=out_dim, + kernel_size=3, + padding=1 + ) + self.conv_shortcut = WanCausalConv3d( + rngs=rngs, + in_channels=in_dim, + out_channels=out_dim, + kernel_size=1 + ) if in_dim != out_dim else Identity() + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): - return x + # Apply shortcut connection + #breakpoint() + h = self.conv_shortcut(x) + + x = self.norm1(x) + x = self.nonlinearity(x) + x = self.conv1(x) + + x = self.norm2(x) + x = self.nonlinearity(x) + x = self.dropout(x) + x = self.conv2(x) + + return x + h class WanAttentionBlock(nnx.Module): def __init__( @@ -397,11 +436,11 @@ def __init__( # init block self.conv_in = WanCausalConv3d( + rngs=rngs, in_channels=3, out_channels=dims[0], kernel_size=3, padding=1, - rngs=rngs ) # downsample blocks @@ -439,6 +478,12 @@ def __init__( ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + # (1, 1, 480, 720, 3) + x = self.conv_in(x) + # (1, 1, 480, 720, 96) + for layer in self.down_blocks: + x = layer(x) + breakpoint() return x class WanDecoder3d(nnx.Module): @@ -480,7 +525,13 @@ def __init__( scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block - self.conv_in = WanCausalConv3d(in_channels=z_dim, out_channels=dims[0], kernel_size=3, padding=1, rngs=rngs) + self.conv_in = WanCausalConv3d( + rngs=rngs, + in_channels=z_dim, + out_channels=dims[0], + kernel_size=3, + padding=1 + ) # middle_blocks self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1) @@ -516,7 +567,13 @@ def __init__( # output blocks self.norm_out = nnx.RMSNorm(num_features=out_dim, ) self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs) - self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=3, kernel_size=3, padding=1, rngs=rngs) + self.conv_out = WanCausalConv3d( + rngs=rngs, + in_channels=out_dim, + out_channels=3, + kernel_size=3, + padding=1 + ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): breakpoint() @@ -533,7 +590,7 @@ def __init__( dim_mult: Tuple[int] = [1,2,4,4], num_res_blocks: int = 2, attn_scales: List[float] = [], - temporal_downsample: List[bool] = [False, True, True], + temperal_downsample: List[bool] = [False, True, True], dropout: float = 0.0, latents_mean: List[float] = [ -0.7571,-0.7089,-0.9113,0.1075,-0.1745,0.9653,-0.1517, 1.5508, @@ -545,31 +602,59 @@ def __init__( ], ): self.z_dim = z_dim - self.temporal_downsample = temporal_downsample - self.temporal_upsample = temporal_downsample[::-1] - - self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1) - self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1, rngs=rngs) - self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1, rngs=rngs) + self.temperal_downsample = temperal_downsample + self.temporal_upsample = temperal_downsample[::-1] - self.decoder = WanDecoder3d( - base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout + self.encoder = WanEncoder3d( + rngs=rngs, + dim=base_dim, + z_dim=z_dim * 2, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + dropout=dropout, + ) + self.quant_conv = WanCausalConv3d( + rngs=rngs, + in_channels=z_dim * 2, + out_channels=z_dim * 2, + kernel_size=1 + ) + self.post_quant_conv = WanCausalConv3d( + rngs=rngs, + in_channels=z_dim, + out_channels=z_dim, + kernel_size=1, ) + + # self.decoder = WanDecoder3d( + # rngs=rngs, + # dim=base_dim, + # z_dim=z_dim, + # dim_mult=dim_mult, + # num_res_blocks=num_res_blocks, + # attn_scales=attn_scales, + # temperal_upsample=self.temporal_upsample, + # dropout=dropout + # ) self.clear_cache() def clear_cache(self): """ Resets cache dictionaries and indices""" def _count_conv3d(module): count = 0 - node_types = nnx.graph.iter_graph(module, nnx.Module) - for node in node_types: - if isinstance(node.value, WanCausalConv3d): + node_types = nnx.graph.iter_graph([module]) + for path, value in node_types: + #breakpoint() + if isinstance(value, WanCausalConv3d): + print("value: ", value) count +=1 return count - self._conv_num = _count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num + # self._conv_num = _count_conv3d(self.decoder) + # self._conv_idx = [0] + # self._feat_map = [None] * self._conv_num # cache encode self._enc_conv_num = _count_conv3d(self.encoder) self._enc_conv_idx = [0] @@ -581,7 +666,7 @@ def _encode(self, x: jax.Array): x = jnp.transpose(x, (0, 2, 3, 4, 1)) assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}" - self.clear_cache() + #self.clear_cache() t = x.shape[1] iter_ = 1 + (t - 1) // 4 @@ -590,7 +675,7 @@ def _encode(self, x: jax.Array): out = self.encoder( x[:, :1, :, :, :], feat_cache=self._enc_feat_map, - feat_ids=self._enc_conv_idx + feat_idx=self._enc_conv_idx ) else: out_ = self.encoder( @@ -600,11 +685,12 @@ def _encode(self, x: jax.Array): ) out = jnp.concatenate([out, out_], axis=1) - enc = self.quant_conv(out) - mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] - enc = jnp.concatenate([mu, logvar], dim=1) - self.clear_cache() - return enc + # enc = self.quant_conv(out) + # mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] + # enc = jnp.concatenate([mu, logvar], dim=1) + # self.clear_cache() + # return enc + return x def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: """ Encode video into latent distribution.""" diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 0e5df142f..17dd7c45d 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -30,6 +30,7 @@ WanUpsample, AutoencoderKLWan, WanEncoder3d, + WanResidualBlock, WanRMS_norm, WanResample, ZeroPaddedConv2D @@ -318,6 +319,45 @@ def test_3d_conv(self): output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache) assert output_with_larger_cache.shape == (1, 10, 32, 32, 16) + def test_wan_residual(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + # one test + in_dim = out_dim = 96 + batch = 1 + t = 1 + height = 480 + width = 720 + dim = 96 + input_shape = (batch, t, height, width, dim) + expected_output_shape = (batch, t, height, width, dim) + + wan_residual_block = WanResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + rngs=rngs, + ) + dummy_input = jnp.ones(input_shape) + dummy_output = wan_residual_block(dummy_input) + assert dummy_output.shape == expected_output_shape + + # another test + in_dim = 96 + out_dim = 196 + expected_output_shape = (batch, t, height, width, out_dim) + + wan_residual_block = WanResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + rngs=rngs, + ) + dummy_input = jnp.ones(input_shape) + dummy_output = wan_residual_block(dummy_input) + assert dummy_output.shape == expected_output_shape + + + + def test_wan_encode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) @@ -328,16 +368,25 @@ def test_wan_encode(self): attn_scales = [] temperal_downsample = [False, True, True] nonlinearity = "silu" - wan_encoder = WanEncoder3d( - rngs=rngs, - dim=dim, - z_dim=z_dim, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_downsample=temperal_downsample, - non_linearity=nonlinearity + wan_vae = AutoencoderKLWan( + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, ) + # wan_encoder = WanEncoder3d( + # rngs=rngs, + # dim=dim, + # z_dim=z_dim, + # dim_mult=dim_mult, + # num_res_blocks=num_res_blocks, + # attn_scales=attn_scales, + # temperal_downsample=temperal_downsample, + # non_linearity=nonlinearity + # ) batch = 1 channels = 3 t = 49 @@ -345,7 +394,8 @@ def test_wan_encode(self): width = 720 input_shape = (batch, channels, t, height, width) input = jnp.ones(input_shape) - output = wan_encoder(input) + output = wan_vae.encode(input) + breakpoint() From 9b42117db99207e873b38ef2e65b9eca8ae68135 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 24 Apr 2025 23:05:30 +0000 Subject: [PATCH 09/54] add wan vae attention test --- .../models/wan/autoencoder_kl_wan.py | 38 ++++++++++++++++++- src/maxdiffusion/tests/wan_vae_test.py | 25 +++++++++--- 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 8f6e2789a..292264c2a 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -364,9 +364,45 @@ def __init__( rngs: nnx.Rngs ): self.dim = dim + self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False) + self.to_qkv = nnx.Conv( + in_features=dim, + out_features=dim * 3, + kernel_size=1, + rngs=rngs + ) + self.proj = nnx.Conv( + in_features=dim, + out_features=dim, + kernel_size=1, + rngs=rngs + ) def __call__(self, x: jax.Array): - return x + batch_size, time, height, width, channels = x.shape + identity = x + + x = x.reshape(batch_size * time, height, width, channels) + x = self.norm(x) + + qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3) + + qkv = qkv.reshape(batch_size*time, 1, channels * 3, -1) + qkv = jnp.transpose(qkv, (0, 1, 3, 2)) + q, k, v = jnp.split(qkv, 3, axis=-1) + + x = jax.nn.dot_product_attention(q, k, v) + x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels) + + #output projection + x = self.proj(x) + + # Reshape back + x = x.reshape(batch_size, time, height, width, channels) + + return x + identity + + class WanMidBlock(nnx.Module): def __init__( diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 17dd7c45d..f91f3932e 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -33,7 +33,8 @@ WanResidualBlock, WanRMS_norm, WanResample, - ZeroPaddedConv2D + ZeroPaddedConv2D, + WanAttentionBlock ) CACHE_T = 2 @@ -322,7 +323,7 @@ def test_3d_conv(self): def test_wan_residual(self): key = jax.random.key(0) rngs = nnx.Rngs(key) - # one test + # --- Test Case 1: same in/out dim --- in_dim = out_dim = 96 batch = 1 t = 1 @@ -341,7 +342,7 @@ def test_wan_residual(self): dummy_output = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape - # another test + # --- Test Case 1: different in/out dim --- in_dim = 96 out_dim = 196 expected_output_shape = (batch, t, height, width, out_dim) @@ -355,8 +356,22 @@ def test_wan_residual(self): dummy_output = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape - - + def test_wan_attention(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + dim = 384 + batch = 1 + t = 1 + height = 60 + width = 90 + input_shape=(batch, t, height, width, dim) + wan_attention = WanAttentionBlock( + dim=dim, + rngs=rngs + ) + dummy_input = jnp.ones(input_shape) + output = wan_attention(dummy_input) + assert output.shape == input_shape def test_wan_encode(self): key = jax.random.key(0) From 43253250cf6baa9a95be980d55091c67ce16b153 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 24 Apr 2025 23:31:49 +0000 Subject: [PATCH 10/54] add wan mid block vae test --- .../models/wan/autoencoder_kl_wan.py | 14 ++++++++++ src/maxdiffusion/tests/wan_vae_test.py | 27 ++++++++++++------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 292264c2a..f9876e9c8 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -414,8 +414,20 @@ def __init__( num_layers: int = 1 ): self.dim = dim + resnets = [WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs,dropout=dropout, non_linearity=non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(WanAttentionBlock(dim=dim, rngs=rngs)) + resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs,dropout=dropout, non_linearity=non_linearity)) + self.attentions = attentions + self.resnets = resnets def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + x = self.resnets[0](x) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + x = resnet(x) return x class WanUpBlock(nnx.Module): @@ -519,6 +531,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): # (1, 1, 480, 720, 96) for layer in self.down_blocks: x = layer(x) + + x = self.mid_block(x) breakpoint() return x diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index f91f3932e..5cb799cdc 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -30,6 +30,7 @@ WanUpsample, AutoencoderKLWan, WanEncoder3d, + WanMidBlock, WanResidualBlock, WanRMS_norm, WanResample, @@ -373,6 +374,22 @@ def test_wan_attention(self): output = wan_attention(dummy_input) assert output.shape == input_shape + def test_wan_midblock(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + batch = 1 + t = 1 + dim = 384 + height = 60 + width = 90 + input_shape = (batch, t, height, width, dim) + wan_midblock = WanMidBlock( + dim=dim, rngs=rngs + ) + dummy_input = jnp.ones(input_shape) + output = wan_midblock(dummy_input) + assert output.shape == input_shape + def test_wan_encode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) @@ -392,16 +409,6 @@ def test_wan_encode(self): attn_scales=attn_scales, temperal_downsample=temperal_downsample, ) - # wan_encoder = WanEncoder3d( - # rngs=rngs, - # dim=dim, - # z_dim=z_dim, - # dim_mult=dim_mult, - # num_res_blocks=num_res_blocks, - # attn_scales=attn_scales, - # temperal_downsample=temperal_downsample, - # non_linearity=nonlinearity - # ) batch = 1 channels = 3 t = 49 From a5e1e95b4c66185bedbabf02f2e2cdb7602da0d0 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 28 Apr 2025 20:28:02 +0000 Subject: [PATCH 11/54] finishes vae encoder with matching shapes --- .../models/wan/autoencoder_kl_wan.py | 87 ++++++++++++++----- src/maxdiffusion/tests/wan_vae_test.py | 17 ++-- 2 files changed, 74 insertions(+), 30 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index f9876e9c8..a5678d028 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -289,6 +289,14 @@ def __init__( kernel_size=(1, 3, 3), stride=(1, 2, 2) ) + self.time_conv = WanCausalConv3d( + rngs=rngs, + in_channels = dim, + out_channels = dim, + kernel_size=(3, 1, 1), + stride=(2, 1, 1), + padding= (0, 0, 0) + ) else: self.resample = Identity() @@ -302,6 +310,18 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: h_new, w_new, c_new = x.shape[1:] x = x.reshape(n, d, h_new, w_new, c_new) + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = jnp.copy(x) + feat_idx[0] +=1 + else: + cache_x = jnp.copy(x[:, -1:, :, :, :]) + x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x class WanResidualBlock(nnx.Module): @@ -343,7 +363,6 @@ def __init__( def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): # Apply shortcut connection - #breakpoint() h = self.conv_shortcut(x) x = self.norm1(x) @@ -505,7 +524,8 @@ def __init__( if i != len(dim_mult) - 1: mode = "downsample3d" if temperal_downsample[i] else "downsample2d" self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs)) - + scale /= 2.0 + # middle_blocks self.mid_block = WanMidBlock( dim=out_dim, @@ -516,7 +536,12 @@ def __init__( ) # output blocks - self.norm_out = WanRMS_norm(out_dim, images=False, rngs=rngs) + self.norm_out = WanRMS_norm( + out_dim, + channel_first=False, + images=False, + rngs=rngs + ) self.conv_out = WanCausalConv3d( rngs=rngs, in_channels=out_dim, @@ -526,14 +551,39 @@ def __init__( ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): - # (1, 1, 480, 720, 3) - x = self.conv_in(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T:, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + # cache last frame of the last two chunk + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] +=1 + else: + x = self.conv_in(x) # (1, 1, 480, 720, 96) for layer in self.down_blocks: - x = layer(x) + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) - x = self.mid_block(x) - breakpoint() + x = self.mid_block(x, feat_cache, feat_idx) + + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] +=1 + else: + x = self.conv_out(x) return x class WanDecoder3d(nnx.Module): @@ -626,9 +676,7 @@ def __init__( ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): - breakpoint() x = self.conv_in(x) - breakpoint() return x class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -696,9 +744,7 @@ def _count_conv3d(module): count = 0 node_types = nnx.graph.iter_graph([module]) for path, value in node_types: - #breakpoint() if isinstance(value, WanCausalConv3d): - print("value: ", value) count +=1 return count @@ -711,6 +757,7 @@ def _count_conv3d(module): self._enc_feat_map = [None] * self._enc_conv_num def _encode(self, x: jax.Array): + self.clear_cache() if x.shape[-1] != 3: # reshape channel last for JAX x = jnp.transpose(x, (0, 2, 3, 4, 1)) @@ -721,6 +768,7 @@ def _encode(self, x: jax.Array): t = x.shape[1] iter_ = 1 + (t - 1) // 4 for i in range(iter_): + self._enc_conv_idx = [0] if i == 0: out = self.encoder( x[:, :1, :, :, :], @@ -729,18 +777,17 @@ def _encode(self, x: jax.Array): ) else: out_ = self.encoder( - x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx ) out = jnp.concatenate([out, out_], axis=1) - - # enc = self.quant_conv(out) - # mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] - # enc = jnp.concatenate([mu, logvar], dim=1) - # self.clear_cache() + enc = self.quant_conv(out) + mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] + enc = jnp.concatenate([mu, logvar], axis=-1) + self.clear_cache() # return enc - return x + return enc def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: """ Encode video into latent distribution.""" @@ -748,5 +795,5 @@ def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencode posterior = FlaxDiagonalGaussianDistribution(h) if not return_dict: return (posterior, ) - return FlaxAutoencoderKLOutput(latent_dict=posterior) + return FlaxAutoencoderKLOutput(latent_dist=posterior) \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 5cb799cdc..3a69aeb03 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -153,11 +153,11 @@ class WanVaeTest(unittest.TestCase): def setUp(self): WanVaeTest.dummy_data = {} - # def test_clear_cache(self): - # key = jax.random.key(0) - # rngs = nnx.Rngs(key) - # wan_vae = AutoencoderKLWan(rngs=rngs) - # wan_vae.clear_cache() + def test_clear_cache(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + wan_vae = AutoencoderKLWan(rngs=rngs) + wan_vae.clear_cache() def test_wanrms_norm(self): """Test against the Pytorch implementation""" @@ -394,12 +394,11 @@ def test_wan_encode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) dim = 96 - z_dim = 32 + z_dim = 16 dim_mult = [1, 2, 4, 4] num_res_blocks = 2 attn_scales = [] temperal_downsample = [False, True, True] - nonlinearity = "silu" wan_vae = AutoencoderKLWan( rngs=rngs, base_dim=dim, @@ -417,9 +416,7 @@ def test_wan_encode(self): input_shape = (batch, channels, t, height, width) input = jnp.ones(input_shape) output = wan_vae.encode(input) - breakpoint() - - + assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16) if __name__ == "__main__": absltest.main() \ No newline at end of file From efe8528b02836723a183e8e43141c34a5c6ddf0d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 28 Apr 2025 20:55:07 +0000 Subject: [PATCH 12/54] add cache logic to modules. --- .../models/wan/autoencoder_kl_wan.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index a5678d028..ade16fe53 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -367,12 +367,33 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.norm1(x) x = self.nonlinearity(x) - x = self.conv1(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) + if cache_x.shape[1] <2 and feat_cache[idx] is not None: + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] +=1 + else: + x = self.conv1(x) x = self.norm2(x) x = self.nonlinearity(x) x = self.dropout(x) - x = self.conv2(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) + if cache_x.shape[1] <2 and feat_cache[idx] is not None: + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] +=1 + else: + x = self.conv2(x) return x + h @@ -442,7 +463,7 @@ def __init__( self.resnets = resnets def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): - x = self.resnets[0](x) + x = self.resnets[0](x, feat_cache, feat_idx) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: x = attn(x) From cf68754ef7efeb4263a887baa8c625e8dd370027 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 29 Apr 2025 18:32:23 +0000 Subject: [PATCH 13/54] adds decoder and checks matching resolutions. --- .../models/wan/autoencoder_kl_wan.py | 186 +++++++++++++----- src/maxdiffusion/tests/wan_vae_test.py | 33 ++++ 2 files changed, 174 insertions(+), 45 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index ade16fe53..060bda017 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -22,7 +22,11 @@ from ...configuration_utils import ConfigMixin, flax_register_to_config from ..modeling_flax_utils import FlaxModelMixin from ... import common_types -from ..vae_flax import FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution +from ..vae_flax import ( + FlaxAutoencoderKLOutput, + FlaxDiagonalGaussianDistribution, + FlaxDecoderOutput +) BlockSizes = common_types.BlockSizes @@ -82,7 +86,7 @@ def __init__( (0, 0) # Channel dimension - no padding ) - # Store the amount of padding needed *before* the depth dimension for caching logoic + # Store the amount of padding needed *before* the depth dimension for caching logic self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] self.conv = nnx.Conv( @@ -103,7 +107,6 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Arr # Ensure cache has same spatial/channel dims, potentially different depth assert cache_x.shape[0] == x.shape[0] and \ cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" - cache_len = cache_x.shape[1] x = jnp.concatenate([cache_x, x], axis=1) # Concat along depth (D) @@ -166,24 +169,13 @@ def __init__(self, scale_factor: Tuple[float, float], method: str = 'nearest'): def __call__(self, x: jax.Array) -> jax.Array: input_dtype = x.dtype in_shape = x.shape - is_3d = len(in_shape) == 5 - n, d, h, w, c = in_shape if is_3d else(in_shape[0], 1, in_shape[1], in_shape[2], in_shape[3]) - + assert len(in_shape) == 4, "This module only takes tensors with shape of 4." + n, h, w, c = in_shape target_h = int(h * self.scale_factor[0]) target_w = int(w * self.scale_factor[1]) - - # jax.image.resize expects (..., H, W, C) - if is_3d: - x_reshaped = x.reshape(n * d, h, w, c) - out_reshaped = jax.image.resize(x_reshaped.astype(jnp.float32), - (n * d, target_h, target_w, c), - method=self.method) - out = out_reshaped.reshape(n, d, target_h, target_w, c) - else: # Asumming (N, H, W, C) - out = jax.image.resize(x.astype(jnp.float32), - (n, target_h, target_w, c), - method=self.method) - + out = jax.image.resize(x.astype(jnp.float32), + (n, target_h, target_w, c), + method=self.method) return out.astype(input_dtype) class Identity(nnx.Module): @@ -256,7 +248,7 @@ def __init__( ) elif mode == "upsample3d": self.resample = nnx.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), + WanUpsample(scale_factor=(2.0, 2.0, 2.0), method="nearest"), nnx.Conv( dim, dim // 2, @@ -305,6 +297,29 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: n, d, h, w, c = x.shape assert c == self.dim + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], dim=1) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x = x.reshape(n, 2, d, h, w, c) + x = jnp.stack([x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]], axis=2) + x = x.reshape(n, d*2, h, w, c) + d = x.shape[1] x = x.reshape(n*d,h,w,c) x = self.resample(x) h_new, w_new, c_new = x.shape[1:] @@ -371,7 +386,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = feat_idx[0] cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) - if cache_x.shape[1] <2 and feat_cache[idx] is not None: + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv1(x, feat_cache[idx]) @@ -387,7 +402,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = feat_idx[0] cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) - if cache_x.shape[1] <2 and feat_cache[idx] is not None: + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv2(x, feat_cache[idx]) feat_cache[idx] = cache_x @@ -458,7 +473,7 @@ def __init__( attentions = [] for _ in range(num_layers): attentions.append(WanAttentionBlock(dim=dim, rngs=rngs)) - resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs,dropout=dropout, non_linearity=non_linearity)) + resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity)) self.attentions = attentions self.resnets = resnets @@ -467,7 +482,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: x = attn(x) - x = resnet(x) + x = resnet(x, feat_cache, feat_idx) return x class WanUpBlock(nnx.Module): @@ -482,19 +497,31 @@ def __init__( non_linearity: str = "silu" ): # Create layers list - self.resnets = [] + resnets = [] # Add residual blocks and attention if needed current_dim = in_dim for _ in range(num_res_blocks + 1): - self.resnets.append(WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs)) + resnets.append(WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs)) current_dim = out_dim + self.resnets = resnets # Add upsampling layer if needed. self.upsamplers = None if upsample_mode is not None: - self.upsamplers = WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs) + self.upsamplers = [WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs)] def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) return x class WanEncoder3d(nnx.Module): @@ -655,7 +682,13 @@ def __init__( ) # middle_blocks - self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1) + self.mid_block = WanMidBlock( + dim=dims[0], + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + num_layers=1 + ) # upsample blocks self.up_blocks = [] @@ -668,7 +701,6 @@ def __init__( upsample_mode = None if i != len(dim_mult) - 1: upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" - # Crete and add the upsampling block up_block = WanUpBlock( in_dim=in_dim, @@ -686,8 +718,7 @@ def __init__( scale *=2.0 # output blocks - self.norm_out = nnx.RMSNorm(num_features=out_dim, ) - self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs) + self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False) self.conv_out = WanCausalConv3d( rngs=rngs, in_channels=out_dim, @@ -697,7 +728,39 @@ def __init__( ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): - x = self.conv_in(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T: , :, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + # cache last frame of the last two chunk + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + # cache last frame of the last two chunk + cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) return x class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -723,6 +786,8 @@ def __init__( self.z_dim = z_dim self.temperal_downsample = temperal_downsample self.temporal_upsample = temperal_downsample[::-1] + self.latents_mean = latents_mean + self.latents_std = latents_std self.encoder = WanEncoder3d( rngs=rngs, @@ -747,16 +812,16 @@ def __init__( kernel_size=1, ) - # self.decoder = WanDecoder3d( - # rngs=rngs, - # dim=base_dim, - # z_dim=z_dim, - # dim_mult=dim_mult, - # num_res_blocks=num_res_blocks, - # attn_scales=attn_scales, - # temperal_upsample=self.temporal_upsample, - # dropout=dropout - # ) + self.decoder = WanDecoder3d( + rngs=rngs, + dim=base_dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_upsample=self.temporal_upsample, + dropout=dropout + ) self.clear_cache() def clear_cache(self): @@ -769,9 +834,9 @@ def _count_conv3d(module): count +=1 return count - # self._conv_num = _count_conv3d(self.decoder) - # self._conv_idx = [0] - # self._feat_map = [None] * self._conv_num + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num # cache encode self._enc_conv_num = _count_conv3d(self.encoder) self._enc_conv_idx = [0] @@ -817,4 +882,35 @@ def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencode if not return_dict: return (posterior, ) return FlaxAutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]: + self.clear_cache() + iter_ = z.shape[1] + x = self.post_quant_conv(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:,i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:,i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + + out = jnp.concatenate([out, out_], axis=1) + + out = jnp.clip(out, a_min=-1.0, a_max=1.0) + self.clear_cache() + if not return_dict: + return (out, ) + + return FlaxDecoderOutput(sample=out) + + def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]: + if z.shape[-1] != self.z_dim: + # reshape channel last for JAX + x = jnp.transpose(x, (0, 2, 3, 4, 1)) + assert x.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {x.shape}" + decoded = self._decode(z).sample + if not return_dict: + return (decoded,) + return FlaxDecoderOutput(sample=decoded) + \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 3a69aeb03..a2dd50bff 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -390,6 +390,39 @@ def test_wan_midblock(self): output = wan_midblock(dummy_input) assert output.shape == input_shape + def test_wan_decode(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + dim = 96 + z_dim = 16 + dim_mult = [1, 2, 4, 4] + num_res_blocks = 2 + attn_scales = [] + temperal_downsample = [False, True, True] + wan_vae = AutoencoderKLWan( + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + ) + + batch = 1 + t = 13 + channels = 16 + height = 60 + width = 90 + input_shape = (batch, t, height, width, channels) + input = jnp.ones(input_shape) + + latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim) + latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim) + input = input / latents_std + latents_mean + dummy_output = wan_vae.decode(input) + assert dummy_output.sample.shape == (batch, 49, 480, 720, 3) + def test_wan_encode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) From cd16f28b7be3db0ee0b4af531046db9601277997 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 29 Apr 2025 18:36:10 +0000 Subject: [PATCH 14/54] run linter --- src/maxdiffusion/generate_wan.py | 235 +++--- src/maxdiffusion/image_processor.py | 71 +- src/maxdiffusion/models/attention_flax.py | 17 +- src/maxdiffusion/models/embeddings_flax.py | 1 + .../models/wan/autoencoder_kl_wan.py | 686 ++++++++---------- .../wan/transformers/transformer_flux_wan.py | 169 ++--- .../transformers/transformer_flux_wan_nnx.py | 71 +- .../pipelines/wan/pipeline_wan.py | 42 +- ...heduling_flow_match_euler_discrete_flax.py | 39 +- src/maxdiffusion/tests/wan_vae_test.py | 351 +++++---- src/maxdiffusion/video_processor.py | 156 ++-- 11 files changed, 877 insertions(+), 961 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index e19d783de..3a79621d3 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import html from typing import Callable, List, Union, Sequence, Optional import time @@ -37,21 +38,23 @@ setup_initial_state, ) + def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text def prompt_clean(text): - text = whitespace_clean(basic_clean(text)) - return text + text = whitespace_clean(basic_clean(text)) + return text + def _get_t5_prompt_embeds( tokenizer: AutoTokenizer, @@ -63,35 +66,36 @@ def _get_t5_prompt_embeds( dtype: Optional[torch.dtype] = None, ): - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(u) for u in prompt] - batch_size = len(prompt) + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask - seq_lens = mask.gt(0).sum(dim=1).long() - - prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + return prompt_embeds - return prompt_embeds def encode_prompt( tokenizer: AutoTokenizer, @@ -106,77 +110,77 @@ def encode_prompt( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - Whether to use classifier free guidance or not. - num_videos_per_prompt (`int`, *optional*, defaults to 1): - Number of videos that should be generated per prompt. torch device to place the resulting embeddings on - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - device: (`torch.device`, *optional*): - torch device - dtype: (`torch.dtype`, *optional*): - torch dtype - """ - - prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - prompt_embeds = _get_t5_prompt_embeds( - tokenizer=tokenizer, - text_encoder=text_encoder, - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - negative_prompt_embeds = _get_t5_prompt_embeds( - tokenizer=tokenizer, - text_encoder=text_encoder, - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - - return prompt_embeds, negative_prompt_embeds + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = _get_t5_prompt_embeds( + tokenizer=tokenizer, + text_encoder=text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = _get_t5_prompt_embeds( + tokenizer=tokenizer, + text_encoder=text_encoder, + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + def run(config): max_logging.log("Wan 2.1 inference script") @@ -188,17 +192,15 @@ def run(config): global_batch_size = config.per_device_batch_size * jax.local_device_count() tokenizer = AutoTokenizer.from_pretrained( - config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype + config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype ) text_encoder = UMT5EncoderModel.from_pretrained( - config.pretrained_model_name_or_path, subfolder="text_encoder", + config.pretrained_model_name_or_path, + subfolder="text_encoder", ) s0 = time.perf_counter() prompt_embeds, negative_prompt_embeds = encode_prompt( - tokenizer=tokenizer, - text_encoder=text_encoder, - prompt=config.prompt, - negative_prompt=config.negative_prompt + tokenizer=tokenizer, text_encoder=text_encoder, prompt=config.prompt, negative_prompt=config.negative_prompt ) max_logging.log(f"text encoding time: {(time.perf_counter() - s0)}") @@ -209,20 +211,15 @@ def run(config): # ) # breakpoint() - pipeline, params = WanPipeline.from_pretrained( - config.pretrained_model_name_or_path, - vae=None, - transformer=None - ) - - #wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) - + pipeline, params = WanPipeline.from_pretrained(config.pretrained_model_name_or_path, vae=None, transformer=None) + # wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) run(pyconfig.config) + if __name__ == "__main__": app.run(main) diff --git a/src/maxdiffusion/image_processor.py b/src/maxdiffusion/image_processor.py index 2f99d0f7f..76fa7635e 100644 --- a/src/maxdiffusion/image_processor.py +++ b/src/maxdiffusion/image_processor.py @@ -35,51 +35,52 @@ List[torch.FloatTensor], ] + def is_valid_image(image) -> bool: - r""" - Checks if the input is a valid image. + r""" + Checks if the input is a valid image. - A valid image can be: - - A `PIL.Image.Image`. - - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). + A valid image can be: + - A `PIL.Image.Image`. + - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). - Args: - image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): - The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. - Returns: - `bool`: - `True` if the input is a valid image, `False` otherwise. - """ - return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) + Returns: + `bool`: + `True` if the input is a valid image, `False` otherwise. + """ + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) def is_valid_image_imagelist(images): - r""" - Checks if the input is a valid image or list of images. + r""" + Checks if the input is a valid image or list of images. - The input can be one of the following formats: - - A 4D tensor or numpy array (batch of images). - - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or - `torch.Tensor`. - - A list of valid images. + The input can be one of the following formats: + - A 4D tensor or numpy array (batch of images). + - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or + `torch.Tensor`. + - A list of valid images. - Args: - images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): - The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid - images. + Args: + images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): + The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid + images. - Returns: - `bool`: - `True` if the input is valid, `False` otherwise. - """ - if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: - return True - elif is_valid_image(images): - return True - elif isinstance(images, list): - return all(is_valid_image(image) for image in images) - return False + Returns: + `bool`: + `True` if the input is valid, `False` otherwise. + """ + if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: + return True + elif is_valid_image(images): + return True + elif isinstance(images, list): + return all(is_valid_image(image) for image in images) + return False class VaeImageProcessor(ConfigMixin): diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index fc838c9db..2f8946056 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -405,6 +405,7 @@ def chunk_scanner(chunk_idx, _): return jnp.concatenate(res, axis=-3) # fuse the chunked result back + def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) @@ -414,6 +415,7 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) + class FlaxWanAttention(nn.Module): query_dim: int heads: int = 8 @@ -433,7 +435,7 @@ class FlaxWanAttention(nn.Module): out_axis_names: AxisNames = (BATCH, LENGTH, EMBED) precision: jax.lax.Precision = None qkv_bias: bool = False - + def setup(self): if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") @@ -509,13 +511,13 @@ def setup(self): precision=self.precision, ) self.dropout_layer = nn.Dropout(rate=self.dropout) - + def call( - self, - hidden_states: Array, - encoder_hidden_states: Optional[Array], - rotary_emb: Optional[Array], - deterministic: bool = True + self, + hidden_states: Array, + encoder_hidden_states: Optional[Array], + rotary_emb: Optional[Array], + deterministic: bool = True, ): encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states @@ -539,6 +541,7 @@ def call( hidden_states = nn.with_logical_constraint(hidden_states, (BATCH, LENGTH, HEAD)) return self.dropout_layer(hidden_states, deterministic=deterministic) + class FlaxFluxAttention(nn.Module): query_dim: int heads: int = 8 diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index 3377ef6cc..cc961e131 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -56,6 +56,7 @@ def get_sinusoidal_embeddings( signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) return signal + class FlaxTimestepEmbedding(nn.Module): r""" Time step Embedding Module. Learns embeddings for input time steps. diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 060bda017..2cf43f69c 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -22,23 +22,14 @@ from ...configuration_utils import ConfigMixin, flax_register_to_config from ..modeling_flax_utils import FlaxModelMixin from ... import common_types -from ..vae_flax import ( - FlaxAutoencoderKLOutput, - FlaxDiagonalGaussianDistribution, - FlaxDecoderOutput -) +from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) BlockSizes = common_types.BlockSizes CACHE_T = 2 -_ACTIVATIONS = { - "swish": jax.nn.silu, - "silu": jax.nn.silu, - "relu": jax.nn.relu, - "gelu": jax.nn.gelu, - "mish": jax.nn.mish -} +_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "mish": jax.nn.mish} + def get_activation(name: str): func = _ACTIVATIONS.get(name) @@ -46,111 +37,115 @@ def get_activation(name: str): raise ValueError(f"Unknown activation function: {name}") return func + # Helper to ensure kernel_size, stride, padding are tuples of 3 integers def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: - """Canonicalizes a value to a tuple of integers.""" - if isinstance(x, int): - return (x,) * rank - elif isinstance(x, Sequence) and len(x) == rank: - return tuple(x) - else: - raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank} integers. Got {x}") + """Canonicalizes a value to a tuple of integers.""" + if isinstance(x, int): + return (x,) * rank + elif isinstance(x, Sequence) and len(x) == rank: + return tuple(x) + else: + raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank} integers. Got {x}") + class WanCausalConv3d(nnx.Module): + def __init__( - self, - rngs: nnx.Rngs, # rngs are required for initializing parameters, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int, int]], - stride: Union[int, Tuple[int, int, int]] = 1, - padding: Union[int, Tuple[int, int, int]] = 0, - use_bias: bool = True, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", + self, + rngs: nnx.Rngs, # rngs are required for initializing parameters, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + use_bias: bool = True, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", ): - self.kernel_size = _canonicalize_tuple(kernel_size, 3, 'kernel_size') - self.stride = _canonicalize_tuple(stride, 3, 'stride') - padding_tuple = _canonicalize_tuple(padding, 3, 'padding') # (D, H, W) padding amounts + self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") + self.stride = _canonicalize_tuple(stride, 3, "stride") + padding_tuple = _canonicalize_tuple(padding, 3, "padding") # (D, H, W) padding amounts self._causal_padding = ( - (0, 0), # Batch dimension - no padding - (2 * padding_tuple[0], 0), # Depth dimension - causal padding (pad only before) - (padding_tuple[1], padding_tuple[1]), # Height dimension - symmetric padding - (padding_tuple[2], padding_tuple[2]), # Width dimension - symmetric padding - (0, 0) # Channel dimension - no padding + (0, 0), # Batch dimension - no padding + (2 * padding_tuple[0], 0), # Depth dimension - causal padding (pad only before) + (padding_tuple[1], padding_tuple[1]), # Height dimension - symmetric padding + (padding_tuple[2], padding_tuple[2]), # Width dimension - symmetric padding + (0, 0), # Channel dimension - no padding ) # Store the amount of padding needed *before* the depth dimension for caching logic - self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] + self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] self.conv = nnx.Conv( - in_features=in_channels, - out_features=out_channels, - kernel_size=self.kernel_size, - strides=self.stride, - use_bias=use_bias, - padding='VALID', # Handle padding manually - rngs=rngs + in_features=in_channels, + out_features=out_channels, + kernel_size=self.kernel_size, + strides=self.stride, + use_bias=use_bias, + padding="VALID", # Handle padding manually + rngs=rngs, ) - + def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Array: - current_padding = list(self._causal_padding) # Mutable copy + current_padding = list(self._causal_padding) # Mutable copy padding_needed = self._depth_padding_before if cache_x is not None and padding_needed > 0: # Ensure cache has same spatial/channel dims, potentially different depth - assert cache_x.shape[0] == x.shape[0] and \ - cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" + assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" cache_len = cache_x.shape[1] - x = jnp.concatenate([cache_x, x], axis=1) # Concat along depth (D) + x = jnp.concatenate([cache_x, x], axis=1) # Concat along depth (D) padding_needed -= cache_len if padding_needed < 0: # Cache longer than needed padding, trim from start x = x[:, -padding_needed:, ...] - current_padding[1] = (0, 0) # No explicit padding needed now + current_padding[1] = (0, 0) # No explicit padding needed now else: # Update depth padding needed current_padding[1] = (padding_needed, 0) - + # Apply padding if any dimension requires it padding_to_apply = tuple(current_padding) if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads): - x_padded = jnp.pad(x, padding_to_apply, mode='constant', constant_values=0.0) + x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0) else: x_padded = x - + out = self.conv(x_padded) return out + class WanRMS_norm(nnx.Module): + def __init__( - self, - dim: int, - rngs: nnx.Rngs, - channel_first: bool = True, - images: bool = True, - eps: float = 1e-6, - use_bias: bool = False, + self, + dim: int, + rngs: nnx.Rngs, + channel_first: bool = True, + images: bool = True, + eps: float = 1e-6, + use_bias: bool = False, ): broadcastable_dims = (1, 1, 1) if not images else (1, 1) - shape = (dim, *broadcastable_dims) if channel_first else (dim, ) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) self.eps = eps self.channel_first = channel_first - self.scale = dim ** 0.5 + self.scale = dim**0.5 # Initialize gamma as parameter self.gamma = nnx.Param(jnp.ones(shape)) if use_bias: self.bias = nnx.Param(jnp.zeros(shape)) else: self.bias = 0 - + def __call__(self, x: jax.Array) -> jax.Array: normalized = jnp.linalg.norm(x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True) normalized = x / jnp.maximum(normalized, self.eps) @@ -159,13 +154,15 @@ def __call__(self, x: jax.Array) -> jax.Array: return normalized + self.bias.value return normalized + class WanUpsample(nnx.Module): - def __init__(self, scale_factor: Tuple[float, float], method: str = 'nearest'): + + def __init__(self, scale_factor: Tuple[float, float], method: str = "nearest"): # scale_factor for (H, W) # JAX resize works on spatial dims, H, W assumming (N, D, H, W, C) or (N, H, W, C) self.scale_factor = scale_factor self.method = method - + def __call__(self, x: jax.Array) -> jax.Array: input_dtype = x.dtype in_shape = x.shape @@ -173,62 +170,58 @@ def __call__(self, x: jax.Array) -> jax.Array: n, h, w, c = in_shape target_h = int(h * self.scale_factor[0]) target_w = int(w * self.scale_factor[1]) - out = jax.image.resize(x.astype(jnp.float32), - (n, target_h, target_w, c), - method=self.method) + out = jax.image.resize(x.astype(jnp.float32), (n, target_h, target_w, c), method=self.method) return out.astype(input_dtype) + class Identity(nnx.Module): + def __call__(self, x): return x + class ZeroPaddedConv2D(nnx.Module): """ Module for adding padding before conv. Currently it does not add any padding. """ + def __init__( - self, - dim: int, - rngs: nnx.Rngs, - kernel_size: Union[int, Tuple[int, int, int]], - stride: Union[int, Tuple[int, int, int]] = 1, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", + self, + dim: int, + rngs: nnx.Rngs, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", ): - kernel_size = _canonicalize_tuple(kernel_size, 3, 'kernel_size') - stride = _canonicalize_tuple(stride, 3, 'stride') - self.conv = nnx.Conv( - dim, - dim, - kernel_size=kernel_size, - strides=stride, - use_bias=True, - rngs=rngs - ) - + kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") + stride = _canonicalize_tuple(stride, 3, "stride") + self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs) + def __call__(self, x): return self.conv(x) class WanResample(nnx.Module): + def __init__( - self, - dim: int, - mode: str, - rngs: nnx.Rngs, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", + self, + dim: int, + mode: str, + rngs: nnx.Rngs, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", ): self.dim = dim self.mode = mode @@ -236,58 +229,43 @@ def __init__( if mode == "upsample2d": self.resample = nnx.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), - nnx.Conv( - dim, - dim // 2, - kernel_size=(1, 3, 3), - padding='SAME', - use_bias=True, - rngs=rngs, - ) + WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), + nnx.Conv( + dim, + dim // 2, + kernel_size=(1, 3, 3), + padding="SAME", + use_bias=True, + rngs=rngs, + ), ) elif mode == "upsample3d": self.resample = nnx.Sequential( - WanUpsample(scale_factor=(2.0, 2.0, 2.0), method="nearest"), - nnx.Conv( - dim, - dim // 2, - kernel_size=(1, 3, 3), - padding='SAME', - use_bias=True, - rngs=rngs, - ) + WanUpsample(scale_factor=(2.0, 2.0, 2.0), method="nearest"), + nnx.Conv( + dim, + dim // 2, + kernel_size=(1, 3, 3), + padding="SAME", + use_bias=True, + rngs=rngs, + ), ) self.time_conv = WanCausalConv3d( - rngs=rngs, - in_channels=dim, - out_channels=dim * 2, - kernel_size=(3, 1, 1), - padding=(1, 0, 0), + rngs=rngs, + in_channels=dim, + out_channels=dim * 2, + kernel_size=(3, 1, 1), + padding=(1, 0, 0), ) elif mode == "downsample2d": # TODO - do I need to transpose? - self.resample = ZeroPaddedConv2D( - dim=dim, - rngs=rngs, - kernel_size=(1, 3, 3), - stride=(1, 2, 2) - ) + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) elif mode == "downsample3d": # TODO - do I need to transpose? - self.resample = ZeroPaddedConv2D( - dim=dim, - rngs=rngs, - kernel_size=(1, 3, 3), - stride=(1, 2, 2) - ) + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) self.time_conv = WanCausalConv3d( - rngs=rngs, - in_channels = dim, - out_channels = dim, - kernel_size=(3, 1, 1), - stride=(2, 1, 1), - padding= (0, 0, 0) + rngs=rngs, in_channels=dim, out_channels=dim, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) ) else: self.resample = Identity() @@ -318,9 +296,9 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: feat_idx[0] += 1 x = x.reshape(n, 2, d, h, w, c) x = jnp.stack([x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]], axis=2) - x = x.reshape(n, d*2, h, w, c) + x = x.reshape(n, d * 2, h, w, c) d = x.shape[1] - x = x.reshape(n*d,h,w,c) + x = x.reshape(n * d, h, w, c) x = self.resample(x) h_new, w_new, c_new = x.shape[1:] x = x.reshape(n, d, h_new, w_new, c_new) @@ -330,7 +308,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: idx = feat_idx[0] if feat_cache[idx] is None: feat_cache[idx] = jnp.copy(x) - feat_idx[0] +=1 + feat_idx[0] += 1 else: cache_x = jnp.copy(x[:, -1:, :, :, :]) x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1)) @@ -338,8 +316,10 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: feat_idx[0] += 1 return x - + + class WanResidualBlock(nnx.Module): + def __init__( self, in_dim: int, @@ -352,29 +332,15 @@ def __init__( # layers self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False) - self.conv1 = WanCausalConv3d( - rngs=rngs, - in_channels=in_dim, - out_channels=out_dim, - kernel_size=3, - padding=1 - ) + self.conv1 = WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=3, padding=1) self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False) self.dropout = nnx.Dropout(dropout, rngs=rngs) - self.conv2 = WanCausalConv3d( - rngs=rngs, - in_channels=out_dim, - out_channels=out_dim, - kernel_size=3, - padding=1 + self.conv2 = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=out_dim, kernel_size=3, padding=1) + self.conv_shortcut = ( + WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=1) + if in_dim != out_dim + else Identity() ) - self.conv_shortcut = WanCausalConv3d( - rngs=rngs, - in_channels=in_dim, - out_channels=out_dim, - kernel_size=1 - ) if in_dim != out_dim else Identity() - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): # Apply shortcut connection @@ -391,7 +357,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x - feat_idx[0] +=1 + feat_idx[0] += 1 else: x = self.conv1(x) @@ -406,50 +372,38 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv2(x, feat_cache[idx]) feat_cache[idx] = cache_x - feat_idx[0] +=1 + feat_idx[0] += 1 else: x = self.conv2(x) return x + h + class WanAttentionBlock(nnx.Module): - def __init__( - self, - dim: int, - rngs: nnx.Rngs - ): + + def __init__(self, dim: int, rngs: nnx.Rngs): self.dim = dim self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False) - self.to_qkv = nnx.Conv( - in_features=dim, - out_features=dim * 3, - kernel_size=1, - rngs=rngs - ) - self.proj = nnx.Conv( - in_features=dim, - out_features=dim, - kernel_size=1, - rngs=rngs - ) - + self.to_qkv = nnx.Conv(in_features=dim, out_features=dim * 3, kernel_size=1, rngs=rngs) + self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=1, rngs=rngs) + def __call__(self, x: jax.Array): batch_size, time, height, width, channels = x.shape identity = x - + x = x.reshape(batch_size * time, height, width, channels) x = self.norm(x) - qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3) + qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3) - qkv = qkv.reshape(batch_size*time, 1, channels * 3, -1) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) qkv = jnp.transpose(qkv, (0, 1, 3, 2)) q, k, v = jnp.split(qkv, 3, axis=-1) x = jax.nn.dot_product_attention(q, k, v) x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels) - #output projection + # output projection x = self.proj(x) # Reshape back @@ -458,25 +412,18 @@ def __call__(self, x: jax.Array): return x + identity - class WanMidBlock(nnx.Module): - def __init__( - self, - dim: int, - rngs: nnx.Rngs, - dropout: float = 0.0, - non_linearity: str = "silu", - num_layers: int = 1 - ): + + def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): self.dim = dim - resnets = [WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs,dropout=dropout, non_linearity=non_linearity)] + resnets = [WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity)] attentions = [] for _ in range(num_layers): attentions.append(WanAttentionBlock(dim=dim, rngs=rngs)) resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity)) self.attentions = attentions self.resnets = resnets - + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.resnets[0](x, feat_cache, feat_idx) for attn, resnet in zip(self.attentions, self.resnets[1:]): @@ -485,38 +432,42 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = resnet(x, feat_cache, feat_idx) return x + class WanUpBlock(nnx.Module): + def __init__( - self, - in_dim: int, - out_dim: int, - num_res_blocks: int, - rngs: nnx.Rngs, - dropout: float = 0.0, - upsample_mode: Optional[str] = None, - non_linearity: str = "silu" + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + rngs: nnx.Rngs, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", ): # Create layers list resnets = [] # Add residual blocks and attention if needed current_dim = in_dim for _ in range(num_res_blocks + 1): - resnets.append(WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs)) + resnets.append( + WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs) + ) current_dim = out_dim self.resnets = resnets - + # Add upsampling layer if needed. self.upsamplers = None if upsample_mode is not None: self.upsamplers = [WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs)] - + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): for resnet in self.resnets: if feat_cache is not None: x = resnet(x, feat_cache, feat_idx) else: x = resnet(x) - + if self.upsamplers is not None: if feat_cache is not None: x = self.upsamplers[0](x, feat_cache, feat_idx) @@ -524,18 +475,20 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.upsamplers[0](x) return x + class WanEncoder3d(nnx.Module): + def __init__( - self, - rngs: nnx.Rngs, - dim: int =128, - z_dim: int = 4, - dim_mult = [1, 2, 4, 4], - num_res_blocks = 2, - attn_scales = [], - temperal_downsample = [True, True, False], - dropout = 0.0, - non_linearity: str = 'silu', + self, + rngs: nnx.Rngs, + dim: int = 128, + z_dim: int = 4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", ): self.dim = dim self.z_dim = z_dim @@ -551,11 +504,11 @@ def __init__( # init block self.conv_in = WanCausalConv3d( - rngs=rngs, - in_channels=3, - out_channels=dims[0], - kernel_size=3, - padding=1, + rngs=rngs, + in_channels=3, + out_channels=dims[0], + kernel_size=3, + padding=1, ) # downsample blocks @@ -567,37 +520,26 @@ def __init__( if scale in attn_scales: self.down_blocks.append(WanAttentionBlock(dim=out_dim, rngs=rngs)) in_dim = out_dim - + # downsample block if i != len(dim_mult) - 1: mode = "downsample3d" if temperal_downsample[i] else "downsample2d" self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs)) scale /= 2.0 - + # middle_blocks self.mid_block = WanMidBlock( - dim=out_dim, - rngs=rngs, - dropout=dropout, - non_linearity=non_linearity, - num_layers=1, + dim=out_dim, + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + num_layers=1, ) # output blocks - self.norm_out = WanRMS_norm( - out_dim, - channel_first=False, - images=False, - rngs=rngs - ) - self.conv_out = WanCausalConv3d( - rngs=rngs, - in_channels=out_dim, - out_channels=z_dim, - kernel_size=3, - padding=1 - ) - + self.norm_out = WanRMS_norm(out_dim, channel_first=False, images=False, rngs=rngs) + self.conv_out = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=z_dim, kernel_size=3, padding=1) + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = feat_idx[0] @@ -607,7 +549,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv_in(x, feat_cache[idx]) feat_cache[idx] = cache_x - feat_idx[0] +=1 + feat_idx[0] += 1 else: x = self.conv_in(x) # (1, 1, 480, 720, 96) @@ -616,7 +558,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = layer(x, feat_cache, feat_idx) else: x = layer(x) - + x = self.mid_block(x, feat_cache, feat_idx) x = self.norm_out(x) @@ -629,11 +571,12 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv_out(x, feat_cache[idx]) feat_cache[idx] = cache_x - feat_idx[0] +=1 + feat_idx[0] += 1 else: x = self.conv_out(x) return x + class WanDecoder3d(nnx.Module): r""" A 3D decoder module. @@ -647,17 +590,18 @@ class WanDecoder3d(nnx.Module): dropout (float): Dropout rate for the dropout layers. non_linearity (str): Type of non-linearity to use. """ + def __init__( - self, - rngs: nnx.Rngs, - dim: int = 128, - z_dim: int = 128, - dim_mult: List[int] = [1, 2, 4, 4], - num_res_blocks: int = 2, - attn_scales = List[float], - temperal_upsample=[False, True, True], - dropout=0.0, - non_linearity: str = "silu", + self, + rngs: nnx.Rngs, + dim: int = 128, + z_dim: int = 128, + dim_mult: List[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales=List[float], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", ): self.dim = dim self.z_dim = z_dim @@ -673,22 +617,10 @@ def __init__( scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block - self.conv_in = WanCausalConv3d( - rngs=rngs, - in_channels=z_dim, - out_channels=dims[0], - kernel_size=3, - padding=1 - ) + self.conv_in = WanCausalConv3d(rngs=rngs, in_channels=z_dim, out_channels=dims[0], kernel_size=3, padding=1) # middle_blocks - self.mid_block = WanMidBlock( - dim=dims[0], - rngs=rngs, - dropout=dropout, - non_linearity=non_linearity, - num_layers=1 - ) + self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1) # upsample blocks self.up_blocks = [] @@ -696,41 +628,35 @@ def __init__( # residual (+attention) blocks if i > 0: in_dim = in_dim // 2 - + # Determine if we need upsampling upsample_mode = None if i != len(dim_mult) - 1: upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" # Crete and add the upsampling block up_block = WanUpBlock( - in_dim=in_dim, - out_dim=out_dim, - num_res_blocks=num_res_blocks, - dropout=dropout, - upsample_mode=upsample_mode, - non_linearity=non_linearity, - rngs=rngs + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + rngs=rngs, ) self.up_blocks.append(up_block) # Update scale for next iteration if upsample_mode is not None: - scale *=2.0 - + scale *= 2.0 + # output blocks self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False) - self.conv_out = WanCausalConv3d( - rngs=rngs, - in_channels=out_dim, - out_channels=3, - kernel_size=3, - padding=1 - ) - + self.conv_out = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=3, kernel_size=3, padding=1) + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = feat_idx[0] - cache_x = jnp.copy(x[:, -CACHE_T: , :, :, :]) + cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: # cache last frame of the last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) @@ -739,10 +665,10 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): feat_idx[0] += 1 else: x = self.conv_in(x) - + ## middle x = self.mid_block(x, feat_cache, feat_idx) - + ## upsamples for up_block in self.up_blocks: x = up_block(x, feat_cache, feat_idx) @@ -763,25 +689,55 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.conv_out(x) return x + class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): + def __init__( - self, - rngs: nnx.Rngs, - base_dim: int = 96, - z_dim: int = 16, - dim_mult: Tuple[int] = [1,2,4,4], - num_res_blocks: int = 2, - attn_scales: List[float] = [], - temperal_downsample: List[bool] = [False, True, True], - dropout: float = 0.0, - latents_mean: List[float] = [ - -0.7571,-0.7089,-0.9113,0.1075,-0.1745,0.9653,-0.1517, 1.5508, - 0.4134,-0.0715,0.5517,-0.3632,-0.1922,-0.9497,0.2503,-0.2921, - ], - latents_std: List[float] = [ - 2.8184,1.4541,2.3275,2.6558,1.2196,1.7708,2.6052,2.0743, - 3.2687,2.1526,2.8652,1.5579,1.6382,1.1253,2.8251,1.9160, - ], + self, + rngs: nnx.Rngs, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + latents_std: List[float] = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ], ): self.z_dim = z_dim self.temperal_downsample = temperal_downsample @@ -790,48 +746,44 @@ def __init__( self.latents_std = latents_std self.encoder = WanEncoder3d( - rngs=rngs, - dim=base_dim, - z_dim=z_dim * 2, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_downsample=temperal_downsample, - dropout=dropout, - ) - self.quant_conv = WanCausalConv3d( - rngs=rngs, - in_channels=z_dim * 2, - out_channels=z_dim * 2, - kernel_size=1 + rngs=rngs, + dim=base_dim, + z_dim=z_dim * 2, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + dropout=dropout, ) + self.quant_conv = WanCausalConv3d(rngs=rngs, in_channels=z_dim * 2, out_channels=z_dim * 2, kernel_size=1) self.post_quant_conv = WanCausalConv3d( - rngs=rngs, - in_channels=z_dim, - out_channels=z_dim, - kernel_size=1, + rngs=rngs, + in_channels=z_dim, + out_channels=z_dim, + kernel_size=1, ) self.decoder = WanDecoder3d( - rngs=rngs, - dim=base_dim, - z_dim=z_dim, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_upsample=self.temporal_upsample, - dropout=dropout + rngs=rngs, + dim=base_dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_upsample=self.temporal_upsample, + dropout=dropout, ) self.clear_cache() - + def clear_cache(self): - """ Resets cache dictionaries and indices""" + """Resets cache dictionaries and indices""" + def _count_conv3d(module): count = 0 node_types = nnx.graph.iter_graph([module]) for path, value in node_types: if isinstance(value, WanCausalConv3d): - count +=1 + count += 1 return count self._conv_num = _count_conv3d(self.decoder) @@ -848,24 +800,18 @@ def _encode(self, x: jax.Array): # reshape channel last for JAX x = jnp.transpose(x, (0, 2, 3, 4, 1)) assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}" - - #self.clear_cache() + + # self.clear_cache() t = x.shape[1] iter_ = 1 + (t - 1) // 4 for i in range(iter_): self._enc_conv_idx = [0] if i == 0: - out = self.encoder( - x[:, :1, :, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx - ) + out = self.encoder(x[:, :1, :, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) else: out_ = self.encoder( - x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx + x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx ) out = jnp.concatenate([out, out_], axis=1) enc = self.quant_conv(out) @@ -875,14 +821,16 @@ def _encode(self, x: jax.Array): # return enc return enc - def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: - """ Encode video into latent distribution.""" + def encode( + self, x: jax.Array, return_dict: bool = True + ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: + """Encode video into latent distribution.""" h = self._encode(x) posterior = FlaxDiagonalGaussianDistribution(h) if not return_dict: - return (posterior, ) + return (posterior,) return FlaxAutoencoderKLOutput(latent_dist=posterior) - + def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]: self.clear_cache() iter_ = z.shape[1] @@ -890,16 +838,16 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu for i in range(iter_): self._conv_idx = [0] if i == 0: - out = self.decoder(x[:,i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) else: - out_ = self.decoder(x[:,i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) out = jnp.concatenate([out, out_], axis=1) - + out = jnp.clip(out, a_min=-1.0, a_max=1.0) self.clear_cache() if not return_dict: - return (out, ) + return (out,) return FlaxDecoderOutput(sample=out) @@ -912,5 +860,3 @@ def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOut if not return_dict: return (decoded,) return FlaxDecoderOutput(sample=decoded) - - \ No newline at end of file diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py index cdcecf80b..5cd83bbbd 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py @@ -22,16 +22,12 @@ import flax.linen as nn from ...attention_flax import FlaxFeedForward, Fla -from ...embeddings_flax import ( - get_1d_rotary_pos_embed, - FlaxTimesteps, - FlaxTimestepEmbedding, - PixArtAlphaTextProjection -) +from ...embeddings_flax import (get_1d_rotary_pos_embed, FlaxTimesteps, FlaxTimestepEmbedding, PixArtAlphaTextProjection) from ....configuration_utils import ConfigMixin, flax_register_to_config from ...modeling_flax_utils import FlaxModelMixin + class WanRotaryPosEmbed(nn.Module): attention_head_dim: int patch_size: Tuple[int, int, int] @@ -54,9 +50,9 @@ def __call__(self, hidden_states: Array) -> Array: self.freqs = jnp.concatenate(freqs, dim=1) sizes = [ - self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), - self.attention_head_dim // 6, - self.attention_head_dim // 6 + self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), + self.attention_head_dim // 6, + self.attention_head_dim // 6, ] cumulative_sizes = jnp.cumsum(jnp.array(sizes)) split_indices = cumulative_sizes[:-1] @@ -76,6 +72,7 @@ def __call__(self, hidden_states: Array) -> Array: return freqs_final + class WanImageEmbeddings(nn.Module): out_features: int dtype: jnp.dtype = jnp.float32 @@ -85,18 +82,15 @@ class WanImageEmbeddings(nn.Module): @nn.compact def __call__(self, encoder_hidden_states_image: Array) -> Array: hidden_states = nn.LayerNorm( - dtype=jnp.float32, - param_dtype=jnp.float32, + dtype=jnp.float32, + param_dtype=jnp.float32, )(encoder_hidden_states_image) hidden_states = FlaxFeedForward( - self.out_features, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision + self.out_features, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision )(hidden_states) hidden_states = nn.LayerNorm( - dtype=jnp.float32, - param_dtype=jnp.float32, + dtype=jnp.float32, + param_dtype=jnp.float32, )(hidden_states) return hidden_states @@ -113,43 +107,37 @@ class WanTimeTextImageEmbeddings(nn.Module): @nn.compact def __call__(self, timestep: Array, encoder_hidden_states: Array, encoder_hidden_states_image: Array) -> Array: - + timestep = FlaxTimesteps( - dim=self.time_freq_dim, - flip_sin_to_cos=True, - freq_shift=0, - )(timestep) - temb = FlaxTimestepEmbedding( - time_embed_dim=self.dim, - dtype=self.dtype, - weights_dtype=self.weights_dtype + dim=self.time_freq_dim, + flip_sin_to_cos=True, + freq_shift=0, )(timestep) + temb = FlaxTimestepEmbedding(time_embed_dim=self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype)(timestep) timestep_proj = nn.Dense( - self.time_proj_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, + self.time_proj_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, )(nn.silu(temb)) encoder_hidden_states = PixArtAlphaTextProjection( - hidden_size=self.dim, - act_fn="gelu_tanh", - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision + hidden_size=self.dim, + act_fn="gelu_tanh", + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, )(encoder_hidden_states) if encoder_hidden_states_image is not None: encoder_hidden_states_image = WanImageEmbeddings( - out_features=self.dim, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision + out_features=self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision )(encoder_hidden_states_image) - + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + class WanTransformerBlock(nn.Module): dim: int ffn_dim: int @@ -160,38 +148,36 @@ class WanTransformerBlock(nn.Module): added_kv_proj_dim: Optional[int] = None @nn.compact - def __call__( - self, - hidden_states: Array, - encoder_hidden_states: Array, - temb: Array, - rotary_emb: Array - ): + def __call__(self, hidden_states: Array, encoder_hidden_states: Array, temb: Array, rotary_emb: Array): shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( - (scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 + (scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 ) # 1. Self-attention - norm_hidden_states = (nn.LayerNorm( - epsilon=self.eps, - use_bias=False, - use_scale=False, - dtype=jnp.float32, - param_dtype=jnp.float32, - )(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) + norm_hidden_states = ( + nn.LayerNorm( + epsilon=self.eps, + use_bias=False, + use_scale=False, + dtype=jnp.float32, + param_dtype=jnp.float32, + )(hidden_states.astype(jnp.float32)) + * (1 + scale_msa) + + shift_msa + ).astype(hidden_states.dtype) attn_output = FlaxWanAttention( - query_dim=self.dim, - heads=self.num_heads, - dim_head=self.dim // self.num_heads, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - attention_kernel=self.attention_kernel, - mesh=self.mesh, - flash_block_sizes=self.flash_block_sizes, - + query_dim=self.dim, + heads=self.num_heads, + dim_head=self.dim // self.num_heads, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + attention_kernel=self.attention_kernel, + mesh=self.mesh, + flash_block_sizes=self.flash_block_sizes, ) + class WanTransformer3dModel(nn.Module, FlaxModelMixin, ConfigMixin): r""" A Transformer model for video-like data used in the Wan model. @@ -254,15 +240,15 @@ class WanTransformer3dModel(nn.Module, FlaxModelMixin, ConfigMixin): @nn.compact def __call__( - self, - hidden_states: Array, - timestep: Array, - encoder_hidden_states: Array, - encoder_hidden_states_image: Optional[Array] = None, - return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None + self, + hidden_states: Array, + timestep: Array, + encoder_hidden_states: Array, + encoder_hidden_states_image: Optional[Array] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[Array, Dict[str, Array]]: - + inner_dim = self.num_attention_heads * self.attention_head_dim batch_size, num_channels, num_frames, height, width = hidden_states.shape @@ -273,32 +259,29 @@ def __call__( # 1. Patch & position embedding rotary_emb = WanRotaryPosEmbed( - attention_head_dim=self.attention_head_dim, - patch_size=self.patch_size, - max_seq_len=self.rope_max_seq_len + attention_head_dim=self.attention_head_dim, patch_size=self.patch_size, max_seq_len=self.rope_max_seq_len )(hidden_states) hidden_states = nn.Conv( - features=inner_dim, - kernel_size=self.patch_size, - stride=self.patch_size, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, + features=inner_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, )(hidden_states) - flattened_shape = (batch_size, num_channels, -1) # TODO is his num_channels or frames? + flattened_shape = (batch_size, num_channels, -1) # TODO is his num_channels or frames? flattened = hidden_states.reshape(flattened_shape) transposed = jnp.transpose(flattened, (0, 2, 1)) # 2. Condition embeddings # image_embedding_dim=1280 for I2V model temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = WanTimeTextImageEmbeddings( - dim=inner_dim, - time_freq_dim=self.freq_dim, - time_proj_dim=inner_dim * 6, - text_embed_dim=self.text_dim, - image_embed_dim=self.image_dim, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision + dim=inner_dim, + time_freq_dim=self.freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=self.text_dim, + image_embed_dim=self.image_dim, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, )(timestep, encoder_hidden_states, encoder_hidden_states_image) - diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py index e5c1b7145..eedba51cc 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py @@ -23,46 +23,45 @@ BlockSizes = common_types.BlockSizes + class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): + def __init__( - self, - rngs: nnx.Rngs, - model_type='t2v', - patch_size=(1, 2, 2), - text_len=512, - in_dim=16, - dim=2038, - ffn_dim=8192, - freq_dim=256, - text_dim=4096, - out_dim=16, - num_heads=16, - num_layers=32, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=True, - eps=1e-6, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", + self, + rngs: nnx.Rngs, + model_type="t2v", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2038, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", ): self.path_embedding = nnx.Conv( - in_dim, - dim, - kernel_size=patch_size, - strides=patch_size, - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, - kernel_init=nnx.with_partitioning( - nnx.initializers.xavier_uniform(), - ("batch",) - ), - rngs=rngs + in_dim, + dim, + kernel_size=patch_size, + strides=patch_size, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("batch",)), + rngs=rngs, ) def __call__(self, x): diff --git a/src/maxdiffusion/pipelines/wan/pipeline_wan.py b/src/maxdiffusion/pipelines/wan/pipeline_wan.py index 6bd9fb09b..27b5844a8 100644 --- a/src/maxdiffusion/pipelines/wan/pipeline_wan.py +++ b/src/maxdiffusion/pipelines/wan/pipeline_wan.py @@ -18,10 +18,11 @@ from transformers import AutoTokenizer, UMT5EncoderModel import torch from ...models.wan.transformers.transformer_flux_wan_nnx import WanModel -from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan +from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan from ..pipeline_flax_utils import FlaxDiffusionPipeline from ...video_processor import VideoProcessor -#from ...schedulers import FlowMatchEulerDiscreteScheduler +# from ...schedulers import FlowMatchEulerDiscreteScheduler + class WanPipeline(FlaxDiffusionPipeline): @@ -31,40 +32,39 @@ def __init__( text_encoder: UMT5EncoderModel, transformer: WanModel, vae: AutoencoderKLWan, - #scheduler: FlowMatchEulerDiscreteScheduler, + # scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - transformer=transformer, - #scheduler=scheduler + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + # scheduler=scheduler ) self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - def _get_t5_prompt_embds( - self, - prompt: Union[str, List[str]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", ) text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() @@ -81,4 +81,4 @@ def _get_t5_prompt_embds( prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - return prompt_embeds \ No newline at end of file + return prompt_embeds diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py index db96f53fd..ec29a353e 100644 --- a/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py @@ -29,14 +29,17 @@ broadcast_to_shape_from_left, ) + @flax.struct.dataclass class FlowMatchEulerDiscreteSchedulerState: common: CommonSchedulerState + @dataclass class FlowMatchEulerDiscreteSchedulerOutput(FlaxSchedulerOutput): state: FlowMatchEulerDiscreteSchedulerState + class FlowMatchEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): # _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] @@ -45,27 +48,27 @@ class FlowMatchEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): @property def has_state(self): return True - + @register_to_config def __init__( - self, - num_train_timesteps: int = 1000, - shift: float = 1.0, - use_dynamic_shifting: bool = False, - base_shift: Optional[float] = 0.5, - max_shift: Optional[float] = 1.15, - base_image_seq_len: Optional[int] = 256, - max_image_seq_len: Optional[int] = 4096, - invert_sigmas: bool = False, - shift_terminal: Optional[float] = None, - use_karras_sigmas: Optional[bool] = False, - use_exponential_sigmas: Optional[bool] = False, - use_beta_sigmas: Optional[bool] = False, - time_shift_type: str = "exponential", - dtype: jnp.dtype = jnp.float32 + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + time_shift_type: str = "exponential", + dtype: jnp.dtype = jnp.float32, ): self.dtype = dtype - + def create_state(self, common: Optional[CommonSchedulerState] = None) -> FlowMatchEulerDiscreteSchedulerState: if common is None: - common = CommonSchedulerState.create(self) \ No newline at end of file + common = CommonSchedulerState.create(self) diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index a2dd50bff..98e49f130 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -26,133 +26,135 @@ import pytest from absl.testing import absltest from ..models.wan.autoencoder_kl_wan import ( - WanCausalConv3d, - WanUpsample, - AutoencoderKLWan, - WanEncoder3d, - WanMidBlock, - WanResidualBlock, - WanRMS_norm, - WanResample, - ZeroPaddedConv2D, - WanAttentionBlock + WanCausalConv3d, + WanUpsample, + AutoencoderKLWan, + WanEncoder3d, + WanMidBlock, + WanResidualBlock, + WanRMS_norm, + WanResample, + ZeroPaddedConv2D, + WanAttentionBlock, ) CACHE_T = 2 + class TorchWanRMS_norm(nn.Module): - r""" - A custom RMS normalization layer. - - Args: - dim (int): The number of dimensions to normalize over. - channel_first (bool, optional): Whether the input tensor has channels as the first dimension. - Default is True. - images (bool, optional): Whether the input represents image data. Default is True. - bias (bool, optional): Whether to include a learnable bias term. Default is False. - """ - - def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: - super().__init__() - broadcastable_dims = (1, 1, 1) if not images else (1, 1) - shape = (dim, *broadcastable_dims) if channel_first else (dim,) - - self.channel_first = channel_first - self.scale = dim**0.5 - self.gamma = nn.Parameter(torch.ones(shape)) - self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 - - def forward(self, x): - return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + r""" + A custom RMS normalization layer. -class TorchWanResample(nn.Module): - r""" - A custom resampling module for 2D and 3D data. - - Args: - dim (int): The number of input/output channels. - mode (str): The resampling mode. Must be one of: - - 'none': No resampling (identity operation). - - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. - - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. - - 'downsample2d': 2D downsampling with zero-padding and convolution. - - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. - """ - - def __init__(self, dim: int, mode: str) -> None: - super().__init__() - self.dim = dim - self.mode = mode - - # layers - if mode == "upsample2d": - self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) - ) - elif mode == "upsample3d": - self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) - ) - self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) - - elif mode == "downsample2d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - elif mode == "downsample3d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class TorchWanResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 else: - self.resample = nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=[0]): - b, c, t, h, w = x.size() - if self.mode == "upsample3d": - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = "Rep" - feat_idx[0] += 1 - else: - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": - # cache last frame of last two chunk - cache_x = torch.cat( - [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 - ) - if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": - cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) - if feat_cache[idx] == "Rep": - x = self.time_conv(x) - else: - x = self.time_conv(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - - x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) - x = x.reshape(b, c, t * 2, h, w) - t = x.shape[2] - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.resample(x) - x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) - - if self.mode == "downsample3d": - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 - else: - cache_x = x[:, :, -1:, :, :].clone() - x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - return x + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + class WanVaeTest(unittest.TestCase): + def setUp(self): WanVaeTest.dummy_data = {} - + def test_clear_cache(self): key = jax.random.key(0) rngs = nnx.Rngs(key) @@ -161,7 +163,7 @@ def test_clear_cache(self): def test_wanrms_norm(self): """Test against the Pytorch implementation""" - + # --- Test Case 1: images == True --- dim = 96 input_shape = (1, 96, 1, 480, 720) @@ -170,7 +172,7 @@ def test_wanrms_norm(self): input = torch.ones(input_shape) torch_output = model(input) torch_output_np = torch_output.detach().numpy() - + key = jax.random.key(0) rngs = nnx.Rngs(key) wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs) @@ -184,7 +186,7 @@ def test_wanrms_norm(self): input = torch.ones(input_shape) torch_output = model(input) torch_output_np = torch_output.detach().numpy() - + key = jax.random.key(0) rngs = nnx.Rngs(key) wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs, images=False) @@ -192,7 +194,7 @@ def test_wanrms_norm(self): output = wanrms_norm(dummy_input) output_np = np.array(output) assert np.allclose(output_np, torch_output_np) == True - + def test_zero_padded_conv(self): key = jax.random.key(0) @@ -200,19 +202,14 @@ def test_zero_padded_conv(self): dim = 96 kernel_size = 3 - stride= (2, 2) + stride = (2, 2) resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, kernel_size, stride=stride)) input_shape = (1, 96, 480, 720) input = torch.ones(input_shape) output_torch = resample(input) assert output_torch.shape == (1, 96, 240, 360) - model = ZeroPaddedConv2D( - dim=dim, - rngs=rngs, - kernel_size=(1, 3, 3), - stride=(1, 2, 2) - ) + model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) dummy_input = jnp.ones(input_shape) dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1)) output = model(dummy_input) @@ -220,7 +217,7 @@ def test_zero_padded_conv(self): assert output.shape == (1, 96, 240, 360) def test_wan_upsample(self): - batch_size=1 + batch_size = 1 in_depth, in_height, in_width = 10, 32, 32 in_channels = 3 @@ -237,59 +234,49 @@ def test_wan_upsample(self): # --- Test Case 1: depth == 1 --- output = upsample(dummy_input) assert output.shape == (1, 1, 64, 64, 3) - + def test_wan_resample(self): # TODO - needs to test all modes - upsample2d, upsample3d, downsample2d, downsample3d and identity key = jax.random.key(0) rngs = nnx.Rngs(key) - + # --- Test Case 1: downsample2d --- batch = 1 dim = 96 t = 1 h = 480 w = 720 - mode="downsample2d" + mode = "downsample2d" input_shape = (batch, dim, t, h, w) expected_output_shape = (1, dim, 1, 240, 360) # output dim should be (1, 96, 1, 480, 720) dummy_input = torch.ones(input_shape) - torch_wan_resample = TorchWanResample( - dim=dim, - mode=mode - ) + torch_wan_resample = TorchWanResample(dim=dim, mode=mode) torch_output = torch_wan_resample(dummy_input) - assert torch_output.shape == (batch, dim, t, h // 2, w //2) + assert torch_output.shape == (batch, dim, t, h // 2, w // 2) - wan_resample = WanResample( - dim, - mode=mode, - rngs=rngs - ) + wan_resample = WanResample(dim, mode=mode, rngs=rngs) # channels is always last here input_shape = (batch, t, h, w, dim) dummy_input = jnp.ones(input_shape) output = wan_resample(dummy_input) - assert output.shape == (batch, t, h//2, h//2, dim) + assert output.shape == (batch, t, h // 2, h // 2, dim) breakpoint() - + # --- Test Case 1: downsample3d --- dim = 192 input_shape = (1, dim, 1, 240, 360) - torch_wan_resample = WanResample( - dim=dim, - mode="downsample3d" - ) + torch_wan_resample = WanResample(dim=dim, mode="downsample3d") def test_3d_conv(self): key = jax.random.key(0) rngs = nnx.Rngs(key) - batch_size=1 + batch_size = 1 in_depth, in_height, in_width = 10, 32, 32 in_channels = 3 out_channels = 16 - kernel_d, kernel_h, kernel_w = 3, 3, 3 # Kernel size (Depth, Height, Width) - padding_d, padding_h, padding_w = 1, 1, 1 # Base padding (Depth, Height, Width) + kernel_d, kernel_h, kernel_w = 3, 3, 3 # Kernel size (Depth, Height, Width) + padding_d, padding_h, padding_w = 1, 1, 1 # Base padding (Depth, Height, Width) # Create dummy input data dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels)) @@ -300,11 +287,11 @@ def test_3d_conv(self): # Instantiate the module causal_conv_layer = WanCausalConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=(kernel_d, kernel_h, kernel_w), - padding=(padding_d, padding_h, padding_w), - rngs=rngs # Pass rngs for initialization + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(kernel_d, kernel_h, kernel_w), + padding=(padding_d, padding_h, padding_w), + rngs=rngs, # Pass rngs for initialization ) # --- Test Case 1: No Cache --- @@ -316,7 +303,7 @@ def test_3d_conv(self): assert output_with_cache.shape == (1, 10, 32, 32, 16) # --- Test Case 3: With Cache larger than padding --- - larger_cache_depth = 4 # Larger than needed padding (2*padding_d = 2) + larger_cache_depth = 4 # Larger than needed padding (2*padding_d = 2) dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels)) output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache) assert output_with_larger_cache.shape == (1, 10, 32, 32, 16) @@ -335,9 +322,9 @@ def test_wan_residual(self): expected_output_shape = (batch, t, height, width, dim) wan_residual_block = WanResidualBlock( - in_dim=in_dim, - out_dim=out_dim, - rngs=rngs, + in_dim=in_dim, + out_dim=out_dim, + rngs=rngs, ) dummy_input = jnp.ones(input_shape) dummy_output = wan_residual_block(dummy_input) @@ -349,9 +336,9 @@ def test_wan_residual(self): expected_output_shape = (batch, t, height, width, out_dim) wan_residual_block = WanResidualBlock( - in_dim=in_dim, - out_dim=out_dim, - rngs=rngs, + in_dim=in_dim, + out_dim=out_dim, + rngs=rngs, ) dummy_input = jnp.ones(input_shape) dummy_output = wan_residual_block(dummy_input) @@ -365,11 +352,8 @@ def test_wan_attention(self): t = 1 height = 60 width = 90 - input_shape=(batch, t, height, width, dim) - wan_attention = WanAttentionBlock( - dim=dim, - rngs=rngs - ) + input_shape = (batch, t, height, width, dim) + wan_attention = WanAttentionBlock(dim=dim, rngs=rngs) dummy_input = jnp.ones(input_shape) output = wan_attention(dummy_input) assert output.shape == input_shape @@ -383,9 +367,7 @@ def test_wan_midblock(self): height = 60 width = 90 input_shape = (batch, t, height, width, dim) - wan_midblock = WanMidBlock( - dim=dim, rngs=rngs - ) + wan_midblock = WanMidBlock(dim=dim, rngs=rngs) dummy_input = jnp.ones(input_shape) output = wan_midblock(dummy_input) assert output.shape == input_shape @@ -400,13 +382,13 @@ def test_wan_decode(self): attn_scales = [] temperal_downsample = [False, True, True] wan_vae = AutoencoderKLWan( - rngs=rngs, - base_dim=dim, - z_dim=z_dim, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_downsample=temperal_downsample, + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, ) batch = 1 @@ -433,13 +415,13 @@ def test_wan_encode(self): attn_scales = [] temperal_downsample = [False, True, True] wan_vae = AutoencoderKLWan( - rngs=rngs, - base_dim=dim, - z_dim=z_dim, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_downsample=temperal_downsample, + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, ) batch = 1 channels = 3 @@ -451,5 +433,6 @@ def test_wan_encode(self): output = wan_vae.encode(input) assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16) + if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() diff --git a/src/maxdiffusion/video_processor.py b/src/maxdiffusion/video_processor.py index 2da782b46..c29485118 100644 --- a/src/maxdiffusion/video_processor.py +++ b/src/maxdiffusion/video_processor.py @@ -23,91 +23,91 @@ class VideoProcessor(VaeImageProcessor): - r"""Simple video processor.""" + r"""Simple video processor.""" - def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor: - r""" - Preprocesses input video(s). + def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor: + r""" + Preprocesses input video(s). - Args: - video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`): - The input video. It can be one of the following: - * List of the PIL images. - * List of list of PIL images. - * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`). - * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). - * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, - width)`). - * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). - * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width, - num_channels)`. - * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height, - width)`. - height (`int`, *optional*, defaults to `None`): - The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to - get default height. - width (`int`, *optional*`, defaults to `None`): - The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get - the default width. - """ - if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: - warnings.warn( - "Passing `video` as a list of 5d np.ndarray is deprecated." - "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", - FutureWarning, - ) - video = np.concatenate(video, axis=0) - if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5: - warnings.warn( - "Passing `video` as a list of 5d torch.Tensor is deprecated." - "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor", - FutureWarning, - ) - video = torch.cat(video, axis=0) + Args: + video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`): + The input video. It can be one of the following: + * List of the PIL images. + * List of list of PIL images. + * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`). + * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, + width)`). + * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width, + num_channels)`. + * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height, + width)`. + height (`int`, *optional*, defaults to `None`): + The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to + get default height. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get + the default width. + """ + if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", + FutureWarning, + ) + video = np.concatenate(video, axis=0) + if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor", + FutureWarning, + ) + video = torch.cat(video, axis=0) - # ensure the input is a list of videos: - # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) - # - if it is a single video, it is convereted to a list of one video. - if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: - video = list(video) - elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): - video = [video] - elif isinstance(video, list) and is_valid_image_imagelist(video[0]): - video = video - else: - raise ValueError( - "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" - ) + # ensure the input is a list of videos: + # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) + # - if it is a single video, it is convereted to a list of one video. + if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: + video = list(video) + elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): + video = [video] + elif isinstance(video, list) and is_valid_image_imagelist(video[0]): + video = video + else: + raise ValueError( + "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" + ) - video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) - # move the number of channels before the number of frames. - video = video.permute(0, 2, 1, 3, 4) + # move the number of channels before the number of frames. + video = video.permute(0, 2, 1, 3, 4) - return video + return video - def postprocess_video( - self, video: torch.Tensor, output_type: str = "np" - ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: - r""" - Converts a video tensor to a list of frames for export. + def postprocess_video( + self, video: torch.Tensor, output_type: str = "np" + ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: + r""" + Converts a video tensor to a list of frames for export. - Args: - video (`torch.Tensor`): The video as a tensor. - output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. - """ - batch_size = video.shape[0] - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = self.postprocess(batch_vid, output_type) - outputs.append(batch_output) + Args: + video (`torch.Tensor`): The video as a tensor. + output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. + """ + batch_size = video.shape[0] + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = self.postprocess(batch_vid, output_type) + outputs.append(batch_output) - if output_type == "np": - outputs = np.stack(outputs) - elif output_type == "pt": - outputs = torch.stack(outputs) - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") + if output_type == "np": + outputs = np.stack(outputs) + elif output_type == "pt": + outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - return outputs + return outputs From 089f8ac5a0ed00b56f53fce4226b6c2232e89573 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 29 Apr 2025 20:32:34 +0000 Subject: [PATCH 15/54] fix unit tests --- src/maxdiffusion/tests/wan_vae_test.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 98e49f130..fc3f1cb6d 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -221,19 +221,13 @@ def test_wan_upsample(self): in_depth, in_height, in_width = 10, 32, 32 in_channels = 3 - dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels)) + dummy_input = jnp.ones((batch_size * in_depth, in_height, in_width, in_channels)) upsample = WanUpsample(scale_factor=(2.0, 2.0)) # --- Test Case 1: depth > 1 --- output = upsample(dummy_input) - assert output.shape == (1, 10, 64, 64, 3) - - in_depth = 1 - dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels)) - # --- Test Case 1: depth == 1 --- - output = upsample(dummy_input) - assert output.shape == (1, 1, 64, 64, 3) + assert output.shape == (10, 64, 64, 3) def test_wan_resample(self): # TODO - needs to test all modes - upsample2d, upsample3d, downsample2d, downsample3d and identity @@ -260,13 +254,7 @@ def test_wan_resample(self): input_shape = (batch, t, h, w, dim) dummy_input = jnp.ones(input_shape) output = wan_resample(dummy_input) - assert output.shape == (batch, t, h // 2, h // 2, dim) - breakpoint() - - # --- Test Case 1: downsample3d --- - dim = 192 - input_shape = (1, dim, 1, 240, 360) - torch_wan_resample = WanResample(dim=dim, mode="downsample3d") + assert output.shape == (batch, t, h // 2, w // 2, dim) def test_3d_conv(self): key = jax.random.key(0) From 40d423d097efb3527e0d7778e863330128d08dc4 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 2 May 2025 16:07:00 +0000 Subject: [PATCH 16/54] e2e wan vae with weights loading. Still not fully working. --- src/maxdiffusion/configuration_utils.py | 3 +- src/maxdiffusion/generate_wan.py | 1 + src/maxdiffusion/models/flux/util.py | 21 +-- .../models/modeling_flax_pytorch_utils.py | 49 +++++- .../models/wan/autoencoder_kl_wan.py | 26 +-- src/maxdiffusion/models/wan/wan_utils.py | 78 +++++++++ src/maxdiffusion/tests/wan_vae_test.py | 16 +- src/maxdiffusion/utils/__init__.py | 3 +- src/maxdiffusion/utils/export_utils.py | 114 ++++++++++-- src/maxdiffusion/utils/import_utils.py | 23 +++ src/maxdiffusion/utils/loading_utils copy.py | 162 ++++++++++++++++++ src/maxdiffusion/utils/loading_utils.py | 88 +++++++++- 12 files changed, 527 insertions(+), 57 deletions(-) create mode 100644 src/maxdiffusion/models/wan/wan_utils.py create mode 100644 src/maxdiffusion/utils/loading_utils copy.py diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py index 8bcb6d80d..bfd420f72 100644 --- a/src/maxdiffusion/configuration_utils.py +++ b/src/maxdiffusion/configuration_utils.py @@ -464,7 +464,8 @@ def extract_init_dict(cls, config_dict, **kwargs): # remove flax internal keys if hasattr(cls, "_flax_internal_args"): for arg in cls._flax_internal_args: - expected_keys.remove(arg) + if arg in expected_keys: + expected_keys.remove(arg) # 2. Remove attributes that cannot be expected from expected config attributes # remove keys to be ignored diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 3a79621d3..ce33c6806 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -26,6 +26,7 @@ from absl import app from transformers import AutoTokenizer, UMT5EncoderModel from maxdiffusion import pyconfig, max_logging +from maxdiffusion.models.wan.autoencoder_kl_wan import AutoencoderKLWan from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel from maxdiffusion.pipelines.wan.pipeline_wan import WanPipeline diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index 362a39171..26371e4db 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -11,7 +11,11 @@ from jax import numpy as jnp from safetensors import safe_open -from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor) +from ..modeling_flax_pytorch_utils import ( + rename_key, + rename_key_and_reshape_tensor, + torch2jax +) from maxdiffusion import max_logging @@ -32,21 +36,6 @@ class FluxParams: rngs: Array param_dtype: DTypeLike - -def torch2jax(torch_tensor: torch.Tensor) -> Array: - is_bfloat16 = torch_tensor.dtype == torch.bfloat16 - if is_bfloat16: - # upcast the tensor to fp32 - torch_tensor = torch_tensor.float() - - if torch.device.type != "cpu": - torch_tensor = torch_tensor.to("cpu") - - numpy_value = torch_tensor.numpy() - jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None) - return jax_array - - @dataclass class ModelSpec: params: FluxParams diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 9552c69f1..74c7fce50 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -15,18 +15,57 @@ """ PyTorch - Flax general utilities.""" import re +import torch import jax import jax.numpy as jnp from flax.linen import Partitioned from flax.traverse_util import flatten_dict, unflatten_dict from flax.core.frozen_dict import unfreeze from jax.random import PRNGKey - +from chex import Array from ..utils import logging +from .. import max_logging logger = logging.get_logger(__name__) +def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict): + """ + expected_pytree: dict - a pytree that comes from initializing the model. + new_pytree: dict - a pytree that has been created from pytorch weights. + """ + expected_pytree = flatten_dict(expected_pytree) + if len(expected_pytree.keys()) != len(new_pytree.keys()): + set1 = set(expected_pytree.keys()) + set2 = set(new_pytree.keys()) + missing_keys = set1 ^ set2 + max_logging.log(f"missing keys : {missing_keys}") + for key in expected_pytree.keys(): + if key in new_pytree.keys(): + try: + expected_pytree_shape = expected_pytree[key].shape + except Exception: + expected_pytree_shape = expected_pytree[key].value.shape + if expected_pytree_shape != new_pytree[key].shape: + max_logging.log(f"shape mismatch for {key}") + max_logging.log( + f"shape mismatch, expected shape of {expected_pytree[key].shape}, but got shape of {new_pytree[key].shape}" + ) + else: + max_logging.log(f"key: {key} not found...") + +def torch2jax(torch_tensor: torch.Tensor) -> Array: + is_bfloat16 = torch_tensor.dtype == torch.bfloat16 + if is_bfloat16: + # upcast the tensor to fp32 + torch_tensor = torch_tensor.float() + + if torch.device.type != "cpu": + torch_tensor = torch_tensor.to("cpu") + + numpy_value = torch_tensor.numpy() + jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None) + return jax_array def rename_key(key): regex = r"\w+[.]\d+" @@ -93,6 +132,12 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: pt_tensor = pt_tensor.transpose(2, 3, 1, 0) return renamed_pt_tuple_key, pt_tensor + + # 3d conv layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 5: + pt_tensor = pt_tensor.transpose(2, 3, 4, 1, 0) + return renamed_pt_tuple_key, pt_tensor # linear layer renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) @@ -103,6 +148,8 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic # old PyTorch layer norm weight renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) if pt_tuple_key[-1] == "gamma": + renamed_pt_tuple_key = pt_tuple_key + pt_tensor = pt_tensor.flatten() return renamed_pt_tuple_key, pt_tensor # old PyTorch layer norm bias diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 2cf43f69c..ceb7bbe2f 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -200,8 +200,6 @@ def __init__( precision: jax.lax.Precision = None, attention: str = "dot_product", ): - kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") - stride = _canonicalize_tuple(stride, 3, "stride") self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs) def __call__(self, x): @@ -233,7 +231,7 @@ def __init__( nnx.Conv( dim, dim // 2, - kernel_size=(1, 3, 3), + kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, @@ -241,11 +239,11 @@ def __init__( ) elif mode == "upsample3d": self.resample = nnx.Sequential( - WanUpsample(scale_factor=(2.0, 2.0, 2.0), method="nearest"), + WanUpsample(scale_factor=(2.0, 2.0), method="nearest"), nnx.Conv( dim, dim // 2, - kernel_size=(1, 3, 3), + kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, @@ -259,11 +257,9 @@ def __init__( padding=(1, 0, 0), ) elif mode == "downsample2d": - # TODO - do I need to transpose? - self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2)) elif mode == "downsample3d": - # TODO - do I need to transpose? - self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) + self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2)) self.time_conv = WanCausalConv3d( rngs=rngs, in_channels=dim, out_channels=dim, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) ) @@ -334,7 +330,6 @@ def __init__( self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False) self.conv1 = WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=3, padding=1) self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False) - self.dropout = nnx.Dropout(dropout, rngs=rngs) self.conv2 = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=out_dim, kernel_size=3, padding=1) self.conv_shortcut = ( WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=1) @@ -363,7 +358,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.norm2(x) x = self.nonlinearity(x) - x = self.dropout(x) if feat_cache is not None: idx = feat_idx[0] @@ -384,8 +378,8 @@ class WanAttentionBlock(nnx.Module): def __init__(self, dim: int, rngs: nnx.Rngs): self.dim = dim self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False) - self.to_qkv = nnx.Conv(in_features=dim, out_features=dim * 3, kernel_size=1, rngs=rngs) - self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=1, rngs=rngs) + self.to_qkv = nnx.Conv(in_features=dim, out_features=dim * 3, kernel_size=(1, 1), rngs=rngs) + self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=(1, 1), rngs=rngs) def __call__(self, x: jax.Array): batch_size, time, height, width, channels = x.shape @@ -801,8 +795,6 @@ def _encode(self, x: jax.Array): x = jnp.transpose(x, (0, 2, 3, 4, 1)) assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}" - # self.clear_cache() - t = x.shape[1] iter_ = 1 + (t - 1) // 4 for i in range(iter_): @@ -854,8 +846,8 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]: if z.shape[-1] != self.z_dim: # reshape channel last for JAX - x = jnp.transpose(x, (0, 2, 3, 4, 1)) - assert x.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {x.shape}" + z = jnp.transpose(z, (0, 2, 3, 4, 1)) + assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {z.shape}" decoded = self._decode(z).sample if not return_dict: return (decoded,) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py new file mode 100644 index 000000000..1ff994c6d --- /dev/null +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -0,0 +1,78 @@ +import jax +import jax.numpy as jnp +from maxdiffusion import max_logging +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from flax.traverse_util import flatten_dict, unflatten_dict +from ..modeling_flax_pytorch_utils import ( + rename_key, + rename_key_and_reshape_tensor, + torch2jax, + validate_flax_state_dict +) + +def _tuple_str_to_int(in_tuple): + out_list = [] + for item in in_tuple: + try: + out_list.append(int(item)) + except: + out_list.append(item) + return tuple(out_list) + + +def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): + device = jax.devices(device)[0] + with jax.default_device(device): + if hf_download: + ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors") + #breakpoint() + max_logging.log(f"Load and port Wan 2.1 VAE on {device}") + + if ckpt_path is not None: + tensors = {} + with safe_open(ckpt_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = torch2jax(f.get_tensor(k)) + flax_state_dict = {} + cpu = jax.local_devices(backend="cpu")[0] + for pt_key, tensor in tensors.items(): + renamed_pt_key = rename_key(pt_key) + # Order matters + renamed_pt_key = renamed_pt_key.replace("up_blocks_", "up_blocks.") + renamed_pt_key = renamed_pt_key.replace("mid_block_", "mid_block.") + renamed_pt_key = renamed_pt_key.replace("down_blocks_", "down_blocks.") + + renamed_pt_key = renamed_pt_key.replace("conv_in.bias", "conv_in.conv.bias") + renamed_pt_key = renamed_pt_key.replace("conv_in.weight", "conv_in.conv.weight") + renamed_pt_key = renamed_pt_key.replace("conv_out.bias", "conv_out.conv.bias") + renamed_pt_key = renamed_pt_key.replace("conv_out.weight", "conv_out.conv.weight") + renamed_pt_key = renamed_pt_key.replace("attentions_", "attentions.") + renamed_pt_key = renamed_pt_key.replace("resnets_", "resnets.") + renamed_pt_key = renamed_pt_key.replace("upsamplers_", "upsamplers.") + renamed_pt_key = renamed_pt_key.replace("resample_", "resample.") + renamed_pt_key = renamed_pt_key.replace("conv1.bias", "conv1.conv.bias") + renamed_pt_key = renamed_pt_key.replace("conv1.weight", "conv1.conv.weight") + renamed_pt_key = renamed_pt_key.replace("conv2.bias", "conv2.conv.bias") + renamed_pt_key = renamed_pt_key.replace("conv2.weight", "conv2.conv.weight") + renamed_pt_key = renamed_pt_key.replace("time_conv.bias", "time_conv.conv.bias") + renamed_pt_key = renamed_pt_key.replace("time_conv.weight", "time_conv.conv.weight") + renamed_pt_key = renamed_pt_key.replace("quant_conv", "quant_conv.conv") + renamed_pt_key = renamed_pt_key.replace("conv_shortcut", "conv_shortcut.conv") + if "decoder" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("resample.1.bias", "resample.layers.1.bias") + renamed_pt_key = renamed_pt_key.replace("resample.1.weight", "resample.layers.1.weight") + if "encoder" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("resample.1", "resample.conv") + pt_tuple_key = tuple(renamed_pt_key.split(".")) + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes) + flax_key = _tuple_str_to_int(flax_key) + flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) + del tensors + jax.clear_caches() + else: + raise FileNotFoundError(f"Path {ckpt_path} was not found") + + return flax_state_dict \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index fc3f1cb6d..0ae3f3a26 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -29,7 +29,6 @@ WanCausalConv3d, WanUpsample, AutoencoderKLWan, - WanEncoder3d, WanMidBlock, WanResidualBlock, WanRMS_norm, @@ -37,6 +36,7 @@ ZeroPaddedConv2D, WanAttentionBlock, ) +from ..models.wan.wan_utils import load_wan_vae CACHE_T = 2 @@ -421,6 +421,20 @@ def test_wan_encode(self): output = wan_vae.encode(input) assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16) + # def test_load_checkpoint(self): + # pretrained_model_name_or_path = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + # key = jax.random.key(0) + # rngs = nnx.Rngs(key) + # wan_vae = AutoencoderKLWan.from_config( + # pretrained_model_name_or_path, + # subfolder="vae", + # rngs=rngs + # ) + # graphdef, state = nnx.split(wan_vae) + # params = state.to_pure_dict() + # # This replaces random params with the model. + # params = load_wan_vae(pretrained_model_name_or_path, params, "cpu") + if __name__ == "__main__": absltest.main() diff --git a/src/maxdiffusion/utils/__init__.py b/src/maxdiffusion/utils/__init__.py index 199b7c156..256934b3b 100644 --- a/src/maxdiffusion/utils/__init__.py +++ b/src/maxdiffusion/utils/__init__.py @@ -83,7 +83,7 @@ is_xformers_available, requires_backends, ) -from .loading_utils import load_image +from .loading_utils import load_image, load_video from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( @@ -103,7 +103,6 @@ convert_unet_state_dict_to_peft, ) - logger = get_logger(__name__) diff --git a/src/maxdiffusion/utils/export_utils.py b/src/maxdiffusion/utils/export_utils.py index cd31a7baa..461922e94 100644 --- a/src/maxdiffusion/utils/export_utils.py +++ b/src/maxdiffusion/utils/export_utils.py @@ -3,14 +3,14 @@ import struct import tempfile from contextlib import contextmanager -from typing import List +from typing import List, Optional, Union import numpy as np import PIL.Image import PIL.ImageOps -from .import_utils import BACKENDS_MAPPING, is_opencv_available +from .import_utils import BACKENDS_MAPPING, is_imageio_available, is_opencv_available from .logging import get_logger @@ -110,19 +110,97 @@ def export_to_obj(mesh, output_obj_path: str = None): with open(output_obj_path, "w") as f: f.writelines("\n".join(combined_data)) - -def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str: - if is_opencv_available(): - import cv2 - else: - raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) - if output_video_path is None: - output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name - - fourcc = cv2.VideoWriter_fourcc(*"mp4v") - h, w, c = video_frames[0].shape - video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h)) - for i in range(len(video_frames)): - img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) - video_writer.write(img) - return output_video_path +def _legacy_export_to_video( + video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10 +): + if is_opencv_available(): + import cv2 + else: + raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + if isinstance(video_frames[0], np.ndarray): + video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] + + elif isinstance(video_frames[0], PIL.Image.Image): + video_frames = [np.array(frame) for frame in video_frames] + + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + h, w, c = video_frames[0].shape + video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h)) + for i in range(len(video_frames)): + img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + + return output_video_path + +def export_to_video( + video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], + output_video_path: str = None, + fps: int = 10, + quality: float = 5.0, + bitrate: Optional[int] = None, + macro_block_size: Optional[int] = 16, +) -> str: + """ + quality: + Video output quality. Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to + prevent variable bitrate flags to FFMPEG so you can manually specify them using output_params instead. + Specifying a fixed bitrate using `bitrate` disables this parameter. + + bitrate: + Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead. + Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter + rather than specifiying a fixed bitrate with this parameter. + + macro_block_size: + Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number + imageio will tell ffmpeg to scale the image up to the next closest size divisible by this number. Most codecs + are compatible with a macroblock size of 16 (default), some can go smaller (4, 8). To disable this automatic + feature set it to None or 1, however be warned many players can't decode videos that are odd in size and some + codecs will produce poor results or fail. See https://en.wikipedia.org/wiki/Macroblock. + """ + # TODO: Dhruv. Remove by Diffusers release 0.33.0 + # Added to prevent breaking existing code + if not is_imageio_available(): + logger.warning( + ( + "It is recommended to use `export_to_video` with `imageio` and `imageio-ffmpeg` as a backend. \n" + "These libraries are not present in your environment. Attempting to use legacy OpenCV backend to export video. \n" + "Support for the OpenCV backend will be deprecated in a future Diffusers version" + ) + ) + return _legacy_export_to_video(video_frames, output_video_path, fps) + + if is_imageio_available(): + import imageio + else: + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("export_to_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + ( + "Found an existing imageio backend in your environment. Attempting to export video with imageio. \n" + "Unable to find a compatible ffmpeg installation in your environment to use with imageio. Please install via `pip install imageio-ffmpeg" + ) + ) + + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + if isinstance(video_frames[0], np.ndarray): + video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] + + elif isinstance(video_frames[0], PIL.Image.Image): + video_frames = [np.array(frame) for frame in video_frames] + + with imageio.get_writer( + output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size + ) as writer: + for frame in video_frames: + writer.append_data(frame) + + return output_video_path diff --git a/src/maxdiffusion/utils/import_utils.py b/src/maxdiffusion/utils/import_utils.py index 99d88336a..224bad16c 100644 --- a/src/maxdiffusion/utils/import_utils.py +++ b/src/maxdiffusion/utils/import_utils.py @@ -51,6 +51,19 @@ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} +def _is_package_available(pkg_name: str): + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" + + if pkg_exists: + try: + pkg_version = importlib_metadata.version(pkg_name) + logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False + + return pkg_exists, pkg_version + _torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: _torch_available = importlib.util.find_spec("torch") is not None @@ -105,6 +118,7 @@ except importlib_metadata.PackageNotFoundError: _transformers_available = False +_imageio_available, _imageio_version = _is_package_available("imageio") _inflect_available = importlib.util.find_spec("inflect") is not None try: @@ -284,6 +298,8 @@ except importlib_metadata.PackageNotFoundError: _peft_available = False +def is_imageio_available(): + return _imageio_available def is_torch_available(): return _torch_available @@ -486,6 +502,11 @@ def is_peft_available(): {0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0` """ +# docstyle-ignore +IMAGEIO_IMPORT_ERROR = """ +{0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg` +""" + BACKENDS_MAPPING = OrderedDict( [ @@ -506,6 +527,7 @@ def is_peft_available(): ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), + ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), ] ) @@ -705,3 +727,4 @@ def _get_module(self, module_name: str): def __reduce__(self): return (self.__class__, (self._name, self.__file__, self._import_structure)) + diff --git a/src/maxdiffusion/utils/loading_utils copy.py b/src/maxdiffusion/utils/loading_utils copy.py new file mode 100644 index 000000000..fd66aaa4d --- /dev/null +++ b/src/maxdiffusion/utils/loading_utils copy.py @@ -0,0 +1,162 @@ +import os +import tempfile +from typing import Any, Callable, List, Optional, Tuple, Union +from urllib.parse import unquote, urlparse + +import PIL.Image +import PIL.ImageOps +import requests + +from .import_utils import BACKENDS_MAPPING, is_imageio_available + + +def load_image( + image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None +) -> PIL.Image.Image: + """ + Loads `image` to a PIL Image. + + Args: + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*): + A conversion method to apply to the image after loading it. When set to `None` the image will be converted + "RGB". + + Returns: + `PIL.Image.Image`: + A PIL Image. + """ + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + image = PIL.Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path." + ) + elif isinstance(image, PIL.Image.Image): + image = image + else: + raise ValueError( + "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image." + ) + + image = PIL.ImageOps.exif_transpose(image) + + if convert_method is not None: + image = convert_method(image) + else: + image = image.convert("RGB") + + return image + + +def load_video( + video: str, + convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, +) -> List[PIL.Image.Image]: + """ + Loads `video` to a list of PIL Image. + + Args: + video (`str`): + A URL or Path to a video to convert to a list of PIL Image format. + convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): + A conversion method to apply to the video after loading it. When set to `None` the images will be converted + to "RGB". + + Returns: + `List[PIL.Image.Image]`: + The video as a list of PIL images. + """ + is_url = video.startswith("http://") or video.startswith("https://") + is_file = os.path.isfile(video) + was_tempfile_created = False + + if not (is_url or is_file): + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." + ) + + if is_url: + response = requests.get(video, stream=True) + if response.status_code != 200: + raise ValueError(f"Failed to download video. Status code: {response.status_code}") + + parsed_url = urlparse(video) + file_name = os.path.basename(unquote(parsed_url.path)) + + suffix = os.path.splitext(file_name)[1] or ".mp4" + video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name + + was_tempfile_created = True + + video_data = response.iter_content(chunk_size=8192) + with open(video_path, "wb") as f: + for chunk in video_data: + f.write(chunk) + + video = video_path + + pil_images = [] + if video.endswith(".gif"): + gif = PIL.Image.open(video) + try: + while True: + pil_images.append(gif.copy()) + gif.seek(gif.tell() + 1) + except EOFError: + pass + + else: + if is_imageio_available(): + import imageio + else: + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg" + ) + + with imageio.get_reader(video) as reader: + # Read all frames + for frame in reader: + pil_images.append(PIL.Image.fromarray(frame)) + + if was_tempfile_created: + os.remove(video_path) + + if convert_method is not None: + pil_images = convert_method(pil_images) + + return pil_images + + +# Taken from `transformers`. +def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]: + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + return module, tensor_name + + +def get_submodule_by_name(root_module, module_path: str): + current = root_module + parts = module_path.split(".") + for part in parts: + if part.isdigit(): + idx = int(part) + current = current[idx] # e.g., for nn.ModuleList or nn.Sequential + else: + current = getattr(current, part) + return current diff --git a/src/maxdiffusion/utils/loading_utils.py b/src/maxdiffusion/utils/loading_utils.py index 07c08e726..77caa39e0 100644 --- a/src/maxdiffusion/utils/loading_utils.py +++ b/src/maxdiffusion/utils/loading_utils.py @@ -1,9 +1,12 @@ import os -from typing import Union +from typing import Any, Callable, List, Optional, Tuple, Union import PIL.Image import PIL.ImageOps import requests +import tempfile +from urllib.parse import unquote, urlparse +from .import_utils import BACKENDS_MAPPING, is_imageio_available def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: @@ -33,3 +36,86 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: image = PIL.ImageOps.exif_transpose(image) image = image.convert("RGB") return image + +def load_video( + video: str, + convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, +) -> List[PIL.Image.Image]: + """ + Loads `video` to a list of PIL Image. + + Args: + video (`str`): + A URL or Path to a video to convert to a list of PIL Image format. + convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): + A conversion method to apply to the video after loading it. When set to `None` the images will be converted + to "RGB". + + Returns: + `List[PIL.Image.Image]`: + The video as a list of PIL images. + """ + is_url = video.startswith("http://") or video.startswith("https://") + is_file = os.path.isfile(video) + was_tempfile_created = False + + if not (is_url or is_file): + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." + ) + + if is_url: + response = requests.get(video, stream=True) + if response.status_code != 200: + raise ValueError(f"Failed to download video. Status code: {response.status_code}") + + parsed_url = urlparse(video) + file_name = os.path.basename(unquote(parsed_url.path)) + + suffix = os.path.splitext(file_name)[1] or ".mp4" + video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name + + was_tempfile_created = True + + video_data = response.iter_content(chunk_size=8192) + with open(video_path, "wb") as f: + for chunk in video_data: + f.write(chunk) + + video = video_path + + pil_images = [] + if video.endswith(".gif"): + gif = PIL.Image.open(video) + try: + while True: + pil_images.append(gif.copy()) + gif.seek(gif.tell() + 1) + except EOFError: + pass + + else: + if is_imageio_available(): + import imageio + else: + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg" + ) + + with imageio.get_reader(video) as reader: + # Read all frames + for frame in reader: + pil_images.append(PIL.Image.fromarray(frame)) + + if was_tempfile_created: + os.remove(video_path) + + if convert_method is not None: + pil_images = convert_method(pil_images) + + return pil_images \ No newline at end of file From 34ebdbea48b2f6f874d1106096a5c05f174d691d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 7 May 2025 18:13:23 +0000 Subject: [PATCH 17/54] debug statements --- .../models/wan/autoencoder_kl_wan.py | 118 +++++++++++++++--- src/maxdiffusion/models/wan/wan_utils.py | 1 - src/maxdiffusion/tests/wan_vae_test.py | 15 +-- 3 files changed, 107 insertions(+), 27 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index ceb7bbe2f..06bd37aaa 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -23,7 +23,7 @@ from ..modeling_flax_utils import FlaxModelMixin from ... import common_types from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) - +import numpy as np BlockSizes = common_types.BlockSizes CACHE_T = 2 @@ -93,11 +93,16 @@ def __init__( rngs=rngs, ) - def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Array: + def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array: + print("wanCausalConv3d, x min: ", np.min(x)) + print("wanCausalConv3d, x max: ", np.max(x)) current_padding = list(self._causal_padding) # Mutable copy padding_needed = self._depth_padding_before if cache_x is not None and padding_needed > 0: + print("WanCausalConv3d, cache.shape: ", cache_x.shape) + print("wanCausalConv3d, cache_x min: ", np.min(cache_x)) + print("wanCausalConv3d, cache_x max: ", np.max(cache_x)) # Ensure cache has same spatial/channel dims, potentially different depth assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" cache_len = cache_x.shape[1] @@ -105,21 +110,34 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Arr padding_needed -= cache_len if padding_needed < 0: + print("wanCausanConv3d, padding_needed < 0") # Cache longer than needed padding, trim from start x = x[:, -padding_needed:, ...] current_padding[1] = (0, 0) # No explicit padding needed now else: # Update depth padding needed + print("wanCausanConv3d, padding_needed > 0") current_padding[1] = (padding_needed, 0) # Apply padding if any dimension requires it padding_to_apply = tuple(current_padding) + print("WanCausalConv3d, before padding x shape: ", x.shape) if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads): + print("WanCausalConv3d, applying padding") x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0) else: + print("WanCausalConv3d, NOT applying padding") x_padded = x + print("WanCausalConv3d, x shape: ", x_padded.shape) + print("wanCausalConv3d, x min: ", np.min(x_padded)) + print("wanCausalConv3d, x max: ", np.max(x_padded)) + # if idx == 12: + # breakpoint() out = self.conv(x_padded) + print("WanCausalConv3d, after conv, x shape: ", out.shape) + print("wanCausalConv3d, x min: ", np.min(out)) + print("wanCausalConv3d, x max: ", np.max(out)) return out @@ -346,11 +364,13 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = feat_idx[0] + print("Before conv1, idx: ", idx) cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) - - x = self.conv1(x, feat_cache[idx]) + x = self.conv1(x, feat_cache[idx], idx) + # if idx == 4: + # breakpoint() feat_cache[idx] = cache_x feat_idx[0] += 1 else: @@ -358,19 +378,34 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.norm2(x) x = self.nonlinearity(x) + idx = feat_idx[0] + # if idx == 4: + # breakpoint() if feat_cache is not None: idx = feat_idx[0] + print("Residual block, idx: ", idx) + # if idx == 14: + # breakpoint() + print("cache_x min: ", np.min(cache_x)) + print("cache_x max: ", np.max(cache_x)) cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) + print("cache_x min: ", np.min(cache_x)) + print("cache_x max: ", np.max(cache_x)) + #breakpoint() x = self.conv2(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = self.conv2(x) - - return x + h + print("before conv shortcut add: x min", np.min(x)) + print("before conv shortcut add: x max", np.max(x)) + x = x + h + print("after conv shortcut add: x min: ", np.min(x)) + print("after conv shortcut add: x max: ", np.max(x)) + return x class WanAttentionBlock(nnx.Module): @@ -382,26 +417,51 @@ def __init__(self, dim: int, rngs: nnx.Rngs): self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=(1, 1), rngs=rngs) def __call__(self, x: jax.Array): - batch_size, time, height, width, channels = x.shape + identity = x + batch_size, time, height, width, channels = x.shape x = x.reshape(batch_size * time, height, width, channels) x = self.norm(x) qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3) - - qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + #breakpoint() + #qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3) qkv = jnp.transpose(qkv, (0, 1, 3, 2)) - q, k, v = jnp.split(qkv, 3, axis=-1) - - x = jax.nn.dot_product_attention(q, k, v) + print("qkv min: ", np.min(qkv)) + print("qkv max: ", np.max(qkv)) + #q, k, v = jnp.split(qkv, 3, axis=-1) + q, k, v = jnp.split(qkv, 3, axis=-2) + print("q min: ", np.min(q)) + print("q max: ", np.max(q)) + print("k min: ", np.min(k)) + print("k min: ", np.max(k)) + print("v min: ", np.min(v)) + print("v min: ", np.max(v)) + #breakpoint() + q = jnp.transpose(q, (0, 1, 3, 2)) + k = jnp.transpose(k, (0, 1, 3, 2)) + v = jnp.transpose(v, (0, 1, 3, 2)) + import torch + import torch.nn.functional as F + q = torch.tensor(np.array(q, dtype=np.float32)) + k = torch.tensor(np.array(k, dtype=np.float32)) + v = torch.tensor(np.array(v, dtype=np.float32)) + #x = jax.nn.dot_product_attention(q, k, v) + x = F.scaled_dot_product_attention(q, k, v) + print("attn min: ", torch.min(x)) + print("attn max: ", torch.max(x)) + #breakpoint() + x = jnp.array(x.detach().numpy()) x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels) # output projection x = self.proj(x) - + #breakpoint() # Reshape back x = x.reshape(batch_size, time, height, width, channels) + #breakpoint() return x + identity @@ -419,11 +479,20 @@ def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity self.resnets = resnets def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + print("WanMidblock...") x = self.resnets[0](x, feat_cache, feat_idx) + print("WanMidBlock resnets[0], x min: ", np.min(x)) + print("WanMidBlock resnets[0], x max: ", np.max(x)) for attn, resnet in zip(self.attentions, self.resnets[1:]): + print("WanMidBlock, for loop, attn len: ", len(self.attentions)) + print("WanMidBlock, for loop, resnets len: ", len(self.resnets)) if attn is not None: x = attn(x) + print("WanMidBlock attn[0], x min: ", np.min(x)) + print("WanMidBlock attn[0], x max: ", np.max(x)) x = resnet(x, feat_cache, feat_idx) + print("WanMidBlock resnets[i], x min: ", np.min(x)) + print("WanMidBlock resnets[i], x max: ", np.max(x)) return x @@ -589,7 +658,7 @@ def __init__( self, rngs: nnx.Rngs, dim: int = 128, - z_dim: int = 128, + z_dim: int = 4, dim_mult: List[int] = [1, 2, 4, 4], num_res_blocks: int = 2, attn_scales=List[float], @@ -662,7 +731,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): ## middle x = self.mid_block(x, feat_cache, feat_idx) - + #breakpoint() ## upsamples for up_block in self.up_blocks: x = up_block(x, feat_cache, feat_idx) @@ -810,7 +879,6 @@ def _encode(self, x: jax.Array): mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] enc = jnp.concatenate([mu, logvar], axis=-1) self.clear_cache() - # return enc return enc def encode( @@ -833,10 +901,22 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) else: out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) - out = jnp.concatenate([out, out_], axis=1) - - out = jnp.clip(out, a_min=-1.0, a_max=1.0) + print("out_.shape: ", out_.shape) + print("out_ min: ", np.min(out_)) + print("out_ max: ", np.max(out_)) + print("out.shape: ", out.shape) + print("out min: ", np.min(out)) + print("out max: ", np.max(out)) + for i in range(len(self._feat_map)): + if isinstance(self._feat_map[i], jax.Array): + print("i: ", i) + print("min: ", np.min(self._feat_map[i])) + print("max: ", np.max(self._feat_map[i])) + else: + print(f"feat_map[{i}] : {self._feat_map[i]}") + # breakpoint() + out = jnp.clip(out, min=-1.0, max=1.0) self.clear_cache() if not return_dict: return (out,) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 1ff994c6d..2a8517954 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -26,7 +26,6 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: with jax.default_device(device): if hf_download: ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors") - #breakpoint() max_logging.log(f"Load and port Wan 2.1 VAE on {device}") if ckpt_path is not None: diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 0ae3f3a26..e71ddbde2 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -40,7 +40,6 @@ CACHE_T = 2 - class TorchWanRMS_norm(nn.Module): r""" A custom RMS normalization layer. @@ -92,16 +91,18 @@ def __init__(self, dim: int, mode: str) -> None: WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) ) elif mode == "upsample3d": - self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) - ) - self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + # self.resample = nn.Sequential( + # WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + # ) + # self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + raise Exception("downsample3d not supported") elif mode == "downsample2d": self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) elif mode == "downsample3d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + raise Exception("downsample3d not supported") + #self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + #self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) else: self.resample = nn.Identity() From 04f4909982697df6cc74c1a80a4dbea28ca58708 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 7 May 2025 21:39:55 +0000 Subject: [PATCH 18/54] solves distored decoded video. Now video is jittery, but frames are ok. --- .../models/wan/autoencoder_kl_wan.py | 31 ++++++------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 06bd37aaa..296e58518 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -286,7 +286,7 @@ def __init__( def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: # Input x: (N, D, H, W, C), assume C = self.dim - n, d, h, w, c = x.shape + b, t, h, w, c = x.shape assert c == self.dim if self.mode == "upsample3d": @@ -308,14 +308,14 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: x = self.time_conv(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 - x = x.reshape(n, 2, d, h, w, c) - x = jnp.stack([x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]], axis=2) - x = x.reshape(n, d * 2, h, w, c) - d = x.shape[1] - x = x.reshape(n * d, h, w, c) + x = x.reshape(b, t, h, w, 2, c) + x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1) + x = x.reshape(b, t * 2, h, w, c) + t = x.shape[1] + x = x.reshape(b * t, h, w, c) x = self.resample(x) h_new, w_new, c_new = x.shape[1:] - x = x.reshape(n, d, h_new, w_new, c_new) + x = x.reshape(b, t, h_new, w_new, c_new) if self.mode == "downsample3d": if feat_cache is not None: @@ -425,7 +425,6 @@ def __call__(self, x: jax.Array): x = self.norm(x) qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3) - #breakpoint() #qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3) qkv = jnp.transpose(qkv, (0, 1, 3, 2)) @@ -439,21 +438,10 @@ def __call__(self, x: jax.Array): print("k min: ", np.max(k)) print("v min: ", np.min(v)) print("v min: ", np.max(v)) - #breakpoint() q = jnp.transpose(q, (0, 1, 3, 2)) k = jnp.transpose(k, (0, 1, 3, 2)) v = jnp.transpose(v, (0, 1, 3, 2)) - import torch - import torch.nn.functional as F - q = torch.tensor(np.array(q, dtype=np.float32)) - k = torch.tensor(np.array(k, dtype=np.float32)) - v = torch.tensor(np.array(v, dtype=np.float32)) - #x = jax.nn.dot_product_attention(q, k, v) - x = F.scaled_dot_product_attention(q, k, v) - print("attn min: ", torch.min(x)) - print("attn max: ", torch.max(x)) - #breakpoint() - x = jnp.array(x.detach().numpy()) + x = jax.nn.dot_product_attention(q, k, v) x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels) # output projection @@ -696,7 +684,7 @@ def __init__( upsample_mode = None if i != len(dim_mult) - 1: upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" - # Crete and add the upsampling block + # Create and add the upsampling block up_block = WanUpBlock( in_dim=in_dim, out_dim=out_dim, @@ -731,7 +719,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): ## middle x = self.mid_block(x, feat_cache, feat_idx) - #breakpoint() ## upsamples for up_block in self.up_blocks: x = up_block(x, feat_cache, feat_idx) From 66146b9dcf63630b16cc10ba1a341c4fa2ee455f Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 8 May 2025 01:00:53 +0000 Subject: [PATCH 19/54] fixes jittery decoder frames in vae. --- .../models/wan/autoencoder_kl_wan.py | 86 +++---------------- 1 file changed, 14 insertions(+), 72 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 296e58518..f9603e51e 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -94,15 +94,10 @@ def __init__( ) def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array: - print("wanCausalConv3d, x min: ", np.min(x)) - print("wanCausalConv3d, x max: ", np.max(x)) current_padding = list(self._causal_padding) # Mutable copy padding_needed = self._depth_padding_before if cache_x is not None and padding_needed > 0: - print("WanCausalConv3d, cache.shape: ", cache_x.shape) - print("wanCausalConv3d, cache_x min: ", np.min(cache_x)) - print("wanCausalConv3d, cache_x max: ", np.max(cache_x)) # Ensure cache has same spatial/channel dims, potentially different depth assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" cache_len = cache_x.shape[1] @@ -110,34 +105,20 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> padding_needed -= cache_len if padding_needed < 0: - print("wanCausanConv3d, padding_needed < 0") # Cache longer than needed padding, trim from start x = x[:, -padding_needed:, ...] current_padding[1] = (0, 0) # No explicit padding needed now else: # Update depth padding needed - print("wanCausanConv3d, padding_needed > 0") current_padding[1] = (padding_needed, 0) # Apply padding if any dimension requires it padding_to_apply = tuple(current_padding) - print("WanCausalConv3d, before padding x shape: ", x.shape) if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads): - print("WanCausalConv3d, applying padding") x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0) else: - print("WanCausalConv3d, NOT applying padding") x_padded = x - - print("WanCausalConv3d, x shape: ", x_padded.shape) - print("wanCausalConv3d, x min: ", np.min(x_padded)) - print("wanCausalConv3d, x max: ", np.max(x_padded)) - # if idx == 12: - # breakpoint() out = self.conv(x_padded) - print("WanCausalConv3d, after conv, x shape: ", out.shape) - print("wanCausalConv3d, x min: ", np.min(out)) - print("wanCausalConv3d, x max: ", np.max(out)) return out @@ -300,8 +281,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": # cache last frame of last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) - if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": - cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], dim=1) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], axis=1) if feat_cache[idx] == "Rep": x = self.time_conv(x) else: @@ -364,13 +345,10 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = feat_idx[0] - print("Before conv1, idx: ", idx) cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv1(x, feat_cache[idx], idx) - # if idx == 4: - # breakpoint() feat_cache[idx] = cache_x feat_idx[0] += 1 else: @@ -379,32 +357,18 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.norm2(x) x = self.nonlinearity(x) idx = feat_idx[0] - # if idx == 4: - # breakpoint() if feat_cache is not None: idx = feat_idx[0] - print("Residual block, idx: ", idx) - # if idx == 14: - # breakpoint() - print("cache_x min: ", np.min(cache_x)) - print("cache_x max: ", np.max(cache_x)) cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) - print("cache_x min: ", np.min(cache_x)) - print("cache_x max: ", np.max(cache_x)) - #breakpoint() x = self.conv2(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = self.conv2(x) - print("before conv shortcut add: x min", np.min(x)) - print("before conv shortcut add: x max", np.max(x)) x = x + h - print("after conv shortcut add: x min: ", np.min(x)) - print("after conv shortcut add: x max: ", np.max(x)) return x @@ -428,16 +392,8 @@ def __call__(self, x: jax.Array): #qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3) qkv = jnp.transpose(qkv, (0, 1, 3, 2)) - print("qkv min: ", np.min(qkv)) - print("qkv max: ", np.max(qkv)) #q, k, v = jnp.split(qkv, 3, axis=-1) q, k, v = jnp.split(qkv, 3, axis=-2) - print("q min: ", np.min(q)) - print("q max: ", np.max(q)) - print("k min: ", np.min(k)) - print("k min: ", np.max(k)) - print("v min: ", np.min(v)) - print("v min: ", np.max(v)) q = jnp.transpose(q, (0, 1, 3, 2)) k = jnp.transpose(k, (0, 1, 3, 2)) v = jnp.transpose(v, (0, 1, 3, 2)) @@ -446,10 +402,8 @@ def __call__(self, x: jax.Array): # output projection x = self.proj(x) - #breakpoint() # Reshape back x = x.reshape(batch_size, time, height, width, channels) - #breakpoint() return x + identity @@ -467,20 +421,11 @@ def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity self.resnets = resnets def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): - print("WanMidblock...") x = self.resnets[0](x, feat_cache, feat_idx) - print("WanMidBlock resnets[0], x min: ", np.min(x)) - print("WanMidBlock resnets[0], x max: ", np.max(x)) for attn, resnet in zip(self.attentions, self.resnets[1:]): - print("WanMidBlock, for loop, attn len: ", len(self.attentions)) - print("WanMidBlock, for loop, resnets len: ", len(self.resnets)) if attn is not None: x = attn(x) - print("WanMidBlock attn[0], x min: ", np.min(x)) - print("WanMidBlock attn[0], x max: ", np.max(x)) x = resnet(x, feat_cache, feat_idx) - print("WanMidBlock resnets[i], x min: ", np.min(x)) - print("WanMidBlock resnets[i], x max: ", np.max(x)) return x @@ -888,21 +833,18 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) else: out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) - out = jnp.concatenate([out, out_], axis=1) - print("out_.shape: ", out_.shape) - print("out_ min: ", np.min(out_)) - print("out_ max: ", np.max(out_)) - print("out.shape: ", out.shape) - print("out min: ", np.min(out)) - print("out max: ", np.max(out)) - for i in range(len(self._feat_map)): - if isinstance(self._feat_map[i], jax.Array): - print("i: ", i) - print("min: ", np.min(self._feat_map[i])) - print("max: ", np.max(self._feat_map[i])) - else: - print(f"feat_map[{i}] : {self._feat_map[i]}") - # breakpoint() + + # This is to bypass an issue where frame[1] should be frame[2] and vise versa. + # Ideally shouldn't need to do this however, can't find where the frame is going out of sync. + # Most likely due to an incorrect reshaping in the decoder. + fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :] + if len(fm1.shape) == 4: + fm1 = jnp.expand_dims(fm1, axis=0) + fm2 = jnp.expand_dims(fm2, axis=0) + fm3 = jnp.expand_dims(fm3, axis=0) + fm4 = jnp.expand_dims(fm4, axis=0) + + out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1) out = jnp.clip(out, min=-1.0, max=1.0) self.clear_cache() if not return_dict: From 4245b2419e98441c06f9de6abd4794f56cccfb57 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 8 May 2025 16:01:46 +0000 Subject: [PATCH 20/54] cleanup unused code. --- .../models/wan/autoencoder_kl_wan.py | 25 +------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index f9603e51e..e7bbc1e23 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -19,11 +19,10 @@ import jax import jax.numpy as jnp from flax import nnx -from ...configuration_utils import ConfigMixin, flax_register_to_config +from ...configuration_utils import ConfigMixin from ..modeling_flax_utils import FlaxModelMixin from ... import common_types from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) -import numpy as np BlockSizes = common_types.BlockSizes CACHE_T = 2 @@ -60,13 +59,6 @@ def __init__( stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, use_bias: bool = True, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", ): self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") self.stride = _canonicalize_tuple(stride, 3, "stride") @@ -191,13 +183,6 @@ def __init__( rngs: nnx.Rngs, kernel_size: Union[int, Tuple[int, int, int]], stride: Union[int, Tuple[int, int, int]] = 1, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", ): self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs) @@ -212,13 +197,6 @@ def __init__( dim: int, mode: str, rngs: nnx.Rngs, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", ): self.dim = dim self.mode = mode @@ -548,7 +526,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): feat_idx[0] += 1 else: x = self.conv_in(x) - # (1, 1, 480, 720, 96) for layer in self.down_blocks: if feat_cache is not None: x = layer(x, feat_cache, feat_idx) From c0ba5c1bc44448d005d33b8fdbd1750aa5b00958 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 8 May 2025 21:08:47 +0000 Subject: [PATCH 21/54] linting --- end_to_end/tpu/eval_assert.py | 3 +- src/maxdiffusion/generate_wan.py | 226 -------------- src/maxdiffusion/models/flux/util.py | 7 +- .../models/modeling_flax_pytorch_utils.py | 5 +- .../models/wan/autoencoder_kl_wan.py | 95 +++--- .../wan/transformers/transformer_flux_wan.py | 287 ------------------ .../transformers/transformer_flux_wan_nnx.py | 69 ----- src/maxdiffusion/models/wan/wan_utils.py | 16 +- src/maxdiffusion/tests/wan_vae_test.py | 82 +++-- src/maxdiffusion/utils/export_utils.py | 144 ++++----- src/maxdiffusion/utils/import_utils.py | 23 +- src/maxdiffusion/utils/loading_utils copy.py | 162 ---------- src/maxdiffusion/utils/loading_utils.py | 153 +++++----- 13 files changed, 283 insertions(+), 989 deletions(-) delete mode 100644 src/maxdiffusion/generate_wan.py delete mode 100644 src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py delete mode 100644 src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py delete mode 100644 src/maxdiffusion/utils/loading_utils copy.py diff --git a/end_to_end/tpu/eval_assert.py b/end_to_end/tpu/eval_assert.py index 07b5585a6..20fd0d8ae 100644 --- a/end_to_end/tpu/eval_assert.py +++ b/end_to_end/tpu/eval_assert.py @@ -22,7 +22,6 @@ """ - # pylint: skip-file """Reads and asserts over target values""" from absl import app @@ -47,7 +46,7 @@ def test_final_loss(metrics_file, target_loss, num_samples_str="10"): target_loss = float(target_loss) num_samples = int(num_samples_str) with open(metrics_file, "r", encoding="utf8") as _: - last_n_data = get_last_n_data(metrics_file, "learning/loss",num_samples) + last_n_data = get_last_n_data(metrics_file, "learning/loss", num_samples) avg_last_n_data = sum(last_n_data) / len(last_n_data) print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}") print(f"Target loss is {target_loss}") diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py deleted file mode 100644 index ce33c6806..000000000 --- a/src/maxdiffusion/generate_wan.py +++ /dev/null @@ -1,226 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -import html -from typing import Callable, List, Union, Sequence, Optional -import time -import torch -import ftfy -import regex as re -import jax -from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P -from flax import nnx -from absl import app -from transformers import AutoTokenizer, UMT5EncoderModel -from maxdiffusion import pyconfig, max_logging -from maxdiffusion.models.wan.autoencoder_kl_wan import AutoencoderKLWan -from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel -from maxdiffusion.pipelines.wan.pipeline_wan import WanPipeline - -from maxdiffusion.max_utils import ( - device_put_replicated, - get_memory_allocations, - create_device_mesh, - get_flash_block_sizes, - get_precision, - setup_initial_state, -) - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -def prompt_clean(text): - text = whitespace_clean(basic_clean(text)) - return text - - -def _get_t5_prompt_embeds( - tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, - prompt: Union[str, List[str]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(u) for u in prompt] - batch_size = len(prompt) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask - seq_lens = mask.gt(0).sum(dim=1).long() - - prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - return prompt_embeds - - -def encode_prompt( - tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - do_classifier_free_guidance: bool = True, - num_videos_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - Whether to use classifier free guidance or not. - num_videos_per_prompt (`int`, *optional*, defaults to 1): - Number of videos that should be generated per prompt. torch device to place the resulting embeddings on - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - device: (`torch.device`, *optional*): - torch device - dtype: (`torch.dtype`, *optional*): - torch dtype - """ - - prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - prompt_embeds = _get_t5_prompt_embeds( - tokenizer=tokenizer, - text_encoder=text_encoder, - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - negative_prompt_embeds = _get_t5_prompt_embeds( - tokenizer=tokenizer, - text_encoder=text_encoder, - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - - return prompt_embeds, negative_prompt_embeds - - -def run(config): - max_logging.log("Wan 2.1 inference script") - - rng = jax.random.key(config.seed) - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - - global_batch_size = config.per_device_batch_size * jax.local_device_count() - - tokenizer = AutoTokenizer.from_pretrained( - config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype - ) - text_encoder = UMT5EncoderModel.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="text_encoder", - ) - s0 = time.perf_counter() - prompt_embeds, negative_prompt_embeds = encode_prompt( - tokenizer=tokenizer, text_encoder=text_encoder, prompt=config.prompt, negative_prompt=config.negative_prompt - ) - max_logging.log(f"text encoding time: {(time.perf_counter() - s0)}") - - # pipeline, params = WanPipeline.from_pretrained( - # config.pretrained_model_name_or_path, - # #vae=None, - # #transformer=None - # ) - # breakpoint() - - pipeline, params = WanPipeline.from_pretrained(config.pretrained_model_name_or_path, vae=None, transformer=None) - - # wan_transformer = WanModel(rngs=nnx.Rngs(config.seed)) - - -def main(argv: Sequence[str]) -> None: - pyconfig.initialize(argv) - run(pyconfig.config) - - -if __name__ == "__main__": - app.run(main) diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index 26371e4db..504b71e9f 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -11,11 +11,7 @@ from jax import numpy as jnp from safetensors import safe_open -from ..modeling_flax_pytorch_utils import ( - rename_key, - rename_key_and_reshape_tensor, - torch2jax -) +from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax) from maxdiffusion import max_logging @@ -36,6 +32,7 @@ class FluxParams: rngs: Array param_dtype: DTypeLike + @dataclass class ModelSpec: params: FluxParams diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 74c7fce50..d6a448f98 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -29,6 +29,7 @@ logger = logging.get_logger(__name__) + def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict): """ expected_pytree: dict - a pytree that comes from initializing the model. @@ -54,6 +55,7 @@ def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict): else: max_logging.log(f"key: {key} not found...") + def torch2jax(torch_tensor: torch.Tensor) -> Array: is_bfloat16 = torch_tensor.dtype == torch.bfloat16 if is_bfloat16: @@ -67,6 +69,7 @@ def torch2jax(torch_tensor: torch.Tensor) -> Array: jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None) return jax_array + def rename_key(key): regex = r"\w+[.]\d+" pats = re.findall(regex, key) @@ -132,7 +135,7 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: pt_tensor = pt_tensor.transpose(2, 3, 1, 0) return renamed_pt_tuple_key, pt_tensor - + # 3d conv layer renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 5: diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index e7bbc1e23..8325c3707 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -23,6 +23,7 @@ from ..modeling_flax_utils import FlaxModelMixin from ... import common_types from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) + BlockSizes = common_types.BlockSizes CACHE_T = 2 @@ -367,10 +368,10 @@ def __call__(self, x: jax.Array): x = self.norm(x) qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3) - #qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3) qkv = jnp.transpose(qkv, (0, 1, 3, 2)) - #q, k, v = jnp.split(qkv, 3, axis=-1) + # q, k, v = jnp.split(qkv, 3, axis=-1) q, k, v = jnp.split(qkv, 3, axis=-2) q = jnp.transpose(q, (0, 1, 3, 2)) k = jnp.transpose(k, (0, 1, 3, 2)) @@ -662,6 +663,32 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): return x +class AutoencoderKLWanCache: + + def __init__(self, module): + self.module = module + self.clear_cache() + + def clear_cache(self): + """Resets cache dictionaries and indices""" + + def _count_conv3d(module): + count = 0 + node_types = nnx.graph.iter_graph([module]) + for _, value in node_types: + if isinstance(value, WanCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.module.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.module.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): def __init__( @@ -745,29 +772,9 @@ def __init__( temperal_upsample=self.temporal_upsample, dropout=dropout, ) - self.clear_cache() - def clear_cache(self): - """Resets cache dictionaries and indices""" - - def _count_conv3d(module): - count = 0 - node_types = nnx.graph.iter_graph([module]) - for path, value in node_types: - if isinstance(value, WanCausalConv3d): - count += 1 - return count - - self._conv_num = _count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - # cache encode - self._enc_conv_num = _count_conv3d(self.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num - - def _encode(self, x: jax.Array): - self.clear_cache() + def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): + feat_cache.clear_cache() if x.shape[-1] != 3: # reshape channel last for JAX x = jnp.transpose(x, (0, 2, 3, 4, 1)) @@ -776,42 +783,46 @@ def _encode(self, x: jax.Array): t = x.shape[1] iter_ = 1 + (t - 1) // 4 for i in range(iter_): - self._enc_conv_idx = [0] + feat_cache._enc_conv_idx = [0] if i == 0: - out = self.encoder(x[:, :1, :, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + out = self.encoder(x[:, :1, :, :, :], feat_cache=feat_cache._enc_feat_map, feat_idx=feat_cache._enc_conv_idx) else: out_ = self.encoder( - x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx + x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], + feat_cache=feat_cache._enc_feat_map, + feat_idx=feat_cache._enc_conv_idx, ) out = jnp.concatenate([out, out_], axis=1) enc = self.quant_conv(out) mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] enc = jnp.concatenate([mu, logvar], axis=-1) - self.clear_cache() + feat_cache.clear_cache() return enc def encode( - self, x: jax.Array, return_dict: bool = True + self, x: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: """Encode video into latent distribution.""" - h = self._encode(x) + h = self._encode(x, feat_cache) posterior = FlaxDiagonalGaussianDistribution(h) if not return_dict: return (posterior,) return FlaxAutoencoderKLOutput(latent_dist=posterior) - def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]: - self.clear_cache() + def _decode( + self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True + ) -> Union[FlaxDecoderOutput, jax.Array]: + feat_cache.clear_cache() iter_ = z.shape[1] x = self.post_quant_conv(z) for i in range(iter_): - self._conv_idx = [0] + feat_cache._conv_idx = [0] if i == 0: - out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=feat_cache._feat_map, feat_idx=feat_cache._conv_idx) else: - out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) - - # This is to bypass an issue where frame[1] should be frame[2] and vise versa. + out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=feat_cache._feat_map, feat_idx=feat_cache._conv_idx) + + # This is to bypass an issue where frame[1] should be frame[2] and vise versa. # Ideally shouldn't need to do this however, can't find where the frame is going out of sync. # Most likely due to an incorrect reshaping in the decoder. fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :] @@ -820,21 +831,23 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu fm2 = jnp.expand_dims(fm2, axis=0) fm3 = jnp.expand_dims(fm3, axis=0) fm4 = jnp.expand_dims(fm4, axis=0) - + out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1) out = jnp.clip(out, min=-1.0, max=1.0) - self.clear_cache() + feat_cache.clear_cache() if not return_dict: return (out,) return FlaxDecoderOutput(sample=out) - def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]: + def decode( + self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True + ) -> Union[FlaxDecoderOutput, jax.Array]: if z.shape[-1] != self.z_dim: # reshape channel last for JAX z = jnp.transpose(z, (0, 2, 3, 4, 1)) assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {z.shape}" - decoded = self._decode(z).sample + decoded = self._decode(z, feat_cache).sample if not return_dict: return (decoded,) return FlaxDecoderOutput(sample=decoded) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py deleted file mode 100644 index 5cd83bbbd..000000000 --- a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan.py +++ /dev/null @@ -1,287 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -from typing import Tuple, Dict, Optional, Any, Union -import jax -import math -import jax.numpy as jnp -from chex import Array -import flax.linen as nn - -from ...attention_flax import FlaxFeedForward, Fla -from ...embeddings_flax import (get_1d_rotary_pos_embed, FlaxTimesteps, FlaxTimestepEmbedding, PixArtAlphaTextProjection) - -from ....configuration_utils import ConfigMixin, flax_register_to_config -from ...modeling_flax_utils import FlaxModelMixin - - -class WanRotaryPosEmbed(nn.Module): - attention_head_dim: int - patch_size: Tuple[int, int, int] - theta: float = 10000.0 - max_seq_len: int - - @nn.compact - def __call__(self, hidden_states: Array) -> Array: - batch_size, num_channels, num_frames, height, width = hidden_states.shape - p_t, p_h, p_w = self.patch_size - ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - - h_dim = w_dim = 2 * (self.attention_head_dim // 6) - t_dim = self.attention_head_dim - h_dim - w_dim - - freqs = [] - for dim in [t_dim, h_dim, w_dim]: - freq = get_1d_rotary_pos_embed(dim, self.max_seq_length, self.theta, freqs_dtype=jnp.float64) - freqs.append(freq) - self.freqs = jnp.concatenate(freqs, dim=1) - - sizes = [ - self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), - self.attention_head_dim // 6, - self.attention_head_dim // 6, - ] - cumulative_sizes = jnp.cumsum(jnp.array(sizes)) - split_indices = cumulative_sizes[:-1] - freqs_split = jnp.split(freqs, split_indices, axis=1) - - freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1) - freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1])) - - freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2) - freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1])) - - freqs_w = jnp.expand_dims(jnp.expand_dims(freqs_split[2][:ppw], axis=0), axis=1) - freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1])) - - freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1) - freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1)) - - return freqs_final - - -class WanImageEmbeddings(nn.Module): - out_features: int - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - - @nn.compact - def __call__(self, encoder_hidden_states_image: Array) -> Array: - hidden_states = nn.LayerNorm( - dtype=jnp.float32, - param_dtype=jnp.float32, - )(encoder_hidden_states_image) - hidden_states = FlaxFeedForward( - self.out_features, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision - )(hidden_states) - hidden_states = nn.LayerNorm( - dtype=jnp.float32, - param_dtype=jnp.float32, - )(hidden_states) - return hidden_states - - -class WanTimeTextImageEmbeddings(nn.Module): - dim: int - time_freq_dim: int - time_proj_dim: int - text_embed_dim: int - image_embed_dim: Optional[int] = None - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - - @nn.compact - def __call__(self, timestep: Array, encoder_hidden_states: Array, encoder_hidden_states_image: Array) -> Array: - - timestep = FlaxTimesteps( - dim=self.time_freq_dim, - flip_sin_to_cos=True, - freq_shift=0, - )(timestep) - temb = FlaxTimestepEmbedding(time_embed_dim=self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype)(timestep) - timestep_proj = nn.Dense( - self.time_proj_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - )(nn.silu(temb)) - encoder_hidden_states = PixArtAlphaTextProjection( - hidden_size=self.dim, - act_fn="gelu_tanh", - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - )(encoder_hidden_states) - - if encoder_hidden_states_image is not None: - encoder_hidden_states_image = WanImageEmbeddings( - out_features=self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision - )(encoder_hidden_states_image) - - return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image - - -class WanTransformerBlock(nn.Module): - dim: int - ffn_dim: int - num_heads: int - qk_norm: str = "rms_norm_across_heads" - cross_attn_norm: bool = False - eps: float = 1e-6 - added_kv_proj_dim: Optional[int] = None - - @nn.compact - def __call__(self, hidden_states: Array, encoder_hidden_states: Array, temb: Array, rotary_emb: Array): - - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( - (scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 - ) - - # 1. Self-attention - norm_hidden_states = ( - nn.LayerNorm( - epsilon=self.eps, - use_bias=False, - use_scale=False, - dtype=jnp.float32, - param_dtype=jnp.float32, - )(hidden_states.astype(jnp.float32)) - * (1 + scale_msa) - + shift_msa - ).astype(hidden_states.dtype) - attn_output = FlaxWanAttention( - query_dim=self.dim, - heads=self.num_heads, - dim_head=self.dim // self.num_heads, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - attention_kernel=self.attention_kernel, - mesh=self.mesh, - flash_block_sizes=self.flash_block_sizes, - ) - - -class WanTransformer3dModel(nn.Module, FlaxModelMixin, ConfigMixin): - r""" - A Transformer model for video-like data used in the Wan model. - - Args: - patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): - 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). - num_attention_heads (`int`, defaults to `40`): - Fixed length for text embeddings. - attention_head_dim (`int`, defaults to `128`): - The number of channels in each head. - in_channels (`int`, defaults to `16`): - The number of channels in the input. - out_channels (`int`, defaults to `16`): - The number of channels in the output. - text_dim (`int`, defaults to `512`): - Input dimension for text embeddings. - freq_dim (`int`, defaults to `256`): - Dimension for sinusoidal time embeddings. - ffn_dim (`int`, defaults to `13824`): - Intermediate dimension in feed-forward network. - num_layers (`int`, defaults to `40`): - The number of layers of transformer blocks to use. - window_size (`Tuple[int]`, defaults to `(-1, -1)`): - Window size for local attention (-1 indicates global attention). - cross_attn_norm (`bool`, defaults to `True`): - Enable cross-attention normalization. - qk_norm (`bool`, defaults to `True`): - Enable query/key normalization. - eps (`float`, defaults to `1e-6`): - Epsilon value for normalization layers. - add_img_emb (`bool`, defaults to `False`): - Whether to use img_emb. - added_kv_proj_dim (`int`, *optional*, defaults to `None`): - The number of channels to use for the added key and value projections. If `None`, no projection is used. - """ - - patch_size: Tuple[int] = (1, 2, 2) - num_attention_heads: int = 40 - attention_head_dim: int = 128 - in_channels: int = 16 - out_channels: int = 16 - text_dim: int = 4096 - freq_dim: int = 256 - ffn_dim: int = 13824 - num_layers: int = 40 - cross_attn_norm: bool = True - qk_norm: Optional[str] = "rms_norm_across_heads" - eps: float = 1e-6 - image_dim: Optional[int] = None - added_kv_proj_dim: Optional[int] = None - rope_max_seq_len: int = 1024 - flash_min_seq_length: int = 4096 - flash_block_sizes: BlockSizes = None - mesh: jax.sharding.Mesh = None - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - attention: str = "dot_product" - - @nn.compact - def __call__( - self, - hidden_states: Array, - timestep: Array, - encoder_hidden_states: Array, - encoder_hidden_states_image: Optional[Array] = None, - return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, - ) -> Union[Array, Dict[str, Array]]: - - inner_dim = self.num_attention_heads * self.attention_head_dim - batch_size, num_channels, num_frames, height, width = hidden_states.shape - - p_t, p_h, p_w = self.config.patch_size - post_patch_num_frames = num_frames // p_t - post_patch_height = height // p_h - post_patch_width = width // p_w - - # 1. Patch & position embedding - rotary_emb = WanRotaryPosEmbed( - attention_head_dim=self.attention_head_dim, patch_size=self.patch_size, max_seq_len=self.rope_max_seq_len - )(hidden_states) - hidden_states = nn.Conv( - features=inner_dim, - kernel_size=self.patch_size, - stride=self.patch_size, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - )(hidden_states) - flattened_shape = (batch_size, num_channels, -1) # TODO is his num_channels or frames? - flattened = hidden_states.reshape(flattened_shape) - transposed = jnp.transpose(flattened, (0, 2, 1)) - - # 2. Condition embeddings - # image_embedding_dim=1280 for I2V model - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = WanTimeTextImageEmbeddings( - dim=inner_dim, - time_freq_dim=self.freq_dim, - time_proj_dim=inner_dim * 6, - text_embed_dim=self.text_dim, - image_embed_dim=self.image_dim, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - )(timestep, encoder_hidden_states, encoder_hidden_states_image) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py b/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py deleted file mode 100644 index eedba51cc..000000000 --- a/src/maxdiffusion/models/wan/transformers/transformer_flux_wan_nnx.py +++ /dev/null @@ -1,69 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -import jax -import jax.numpy as jnp -from flax import nnx -from .... import common_types, max_logging -from ...modeling_flax_utils import FlaxModelMixin -from ....configuration_utils import ConfigMixin - -BlockSizes = common_types.BlockSizes - - -class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): - - def __init__( - self, - rngs: nnx.Rngs, - model_type="t2v", - patch_size=(1, 2, 2), - text_len=512, - in_dim=16, - dim=2038, - ffn_dim=8192, - freq_dim=256, - text_dim=4096, - out_dim=16, - num_heads=16, - num_layers=32, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=True, - eps=1e-6, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", - ): - self.path_embedding = nnx.Conv( - in_dim, - dim, - kernel_size=patch_size, - strides=patch_size, - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("batch",)), - rngs=rngs, - ) - - def __call__(self, x): - x = self.path_embedding(x) - return x diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 2a8517954..b39388089 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -4,12 +4,8 @@ from huggingface_hub import hf_hub_download from safetensors import safe_open from flax.traverse_util import flatten_dict, unflatten_dict -from ..modeling_flax_pytorch_utils import ( - rename_key, - rename_key_and_reshape_tensor, - torch2jax, - validate_flax_state_dict -) +from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict) + def _tuple_str_to_int(in_tuple): out_list = [] @@ -25,7 +21,9 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: device = jax.devices(device)[0] with jax.default_device(device): if hf_download: - ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors") + ckpt_path = hf_hub_download( + pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors" + ) max_logging.log(f"Load and port Wan 2.1 VAE on {device}") if ckpt_path is not None: @@ -73,5 +71,5 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: jax.clear_caches() else: raise FileNotFoundError(f"Path {ckpt_path} was not found") - - return flax_state_dict \ No newline at end of file + + return flax_state_dict diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index e71ddbde2..5e37506d0 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -14,7 +14,7 @@ limitations under the License. """ -import os +import functools import torch import torch.nn as nn import torch.nn.functional as F @@ -25,6 +25,7 @@ import unittest import pytest from absl.testing import absltest +from skimage.metrics import structural_similarity as ssim from ..models.wan.autoencoder_kl_wan import ( WanCausalConv3d, WanUpsample, @@ -35,11 +36,15 @@ WanResample, ZeroPaddedConv2D, WanAttentionBlock, + AutoencoderKLWanCache, ) from ..models.wan.wan_utils import load_wan_vae +from ..utils import load_video +from ..video_processor import VideoProcessor CACHE_T = 2 + class TorchWanRMS_norm(nn.Module): r""" A custom RMS normalization layer. @@ -91,19 +96,12 @@ def __init__(self, dim: int, mode: str) -> None: WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) ) elif mode == "upsample3d": - # self.resample = nn.Sequential( - # WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) - # ) - # self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) raise Exception("downsample3d not supported") elif mode == "downsample2d": self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) elif mode == "downsample3d": raise Exception("downsample3d not supported") - #self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - #self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) - else: self.resample = nn.Identity() @@ -156,12 +154,6 @@ class WanVaeTest(unittest.TestCase): def setUp(self): WanVaeTest.dummy_data = {} - def test_clear_cache(self): - key = jax.random.key(0) - rngs = nnx.Rngs(key) - wan_vae = AutoencoderKLWan(rngs=rngs) - wan_vae.clear_cache() - def test_wanrms_norm(self): """Test against the Pytorch implementation""" @@ -379,7 +371,7 @@ def test_wan_decode(self): attn_scales=attn_scales, temperal_downsample=temperal_downsample, ) - + vae_cache = AutoencoderKLWanCache(wan_vae) batch = 1 t = 13 channels = 16 @@ -391,7 +383,7 @@ def test_wan_decode(self): latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim) latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim) input = input / latents_std + latents_mean - dummy_output = wan_vae.decode(input) + dummy_output = wan_vae.decode(input, feat_cache=vae_cache) assert dummy_output.sample.shape == (batch, 49, 480, 720, 3) def test_wan_encode(self): @@ -412,6 +404,7 @@ def test_wan_encode(self): attn_scales=attn_scales, temperal_downsample=temperal_downsample, ) + vae_cache = AutoencoderKLWanCache(wan_vae) batch = 1 channels = 3 t = 49 @@ -419,22 +412,51 @@ def test_wan_encode(self): width = 720 input_shape = (batch, channels, t, height, width) input = jnp.ones(input_shape) - output = wan_vae.encode(input) + output = wan_vae.encode(input, feat_cache=vae_cache) assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16) - # def test_load_checkpoint(self): - # pretrained_model_name_or_path = "Wan-AI/Wan2.1-T2V-14B-Diffusers" - # key = jax.random.key(0) - # rngs = nnx.Rngs(key) - # wan_vae = AutoencoderKLWan.from_config( - # pretrained_model_name_or_path, - # subfolder="vae", - # rngs=rngs - # ) - # graphdef, state = nnx.split(wan_vae) - # params = state.to_pure_dict() - # # This replaces random params with the model. - # params = load_wan_vae(pretrained_model_name_or_path, params, "cpu") + def test_load_checkpoint(self): + def vae_encode(video, wan_vae, vae_cache, key): + latent = wan_vae.encode(video, feat_cache=vae_cache) + latent = latent.latent_dist.sample(key) + return latent + + pretrained_model_name_or_path = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + key = jax.random.key(0) + rngs = nnx.Rngs(key) + wan_vae = AutoencoderKLWan.from_config(pretrained_model_name_or_path, subfolder="vae", rngs=rngs) + vae_cache = AutoencoderKLWanCache(wan_vae) + video_path, fps = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4", 8 + video = load_video(video_path) + + vae_scale_factor_spatial = 2 ** len(wan_vae.temperal_downsample) + video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial) + width, height = video[0].size + video = video_processor.preprocess_video(video, height=height, width=width) # .to(dtype=jnp.float32) + original_video = jnp.array(np.array(video), dtype=jnp.bfloat16) + + graphdef, state = nnx.split(wan_vae) + params = state.to_pure_dict() + # This replaces random params with the model. + params = load_wan_vae(pretrained_model_name_or_path, params, "cpu") + params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) + wan_vae = nnx.merge(graphdef, params) + + p_vae_encode = jax.jit(functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key)) + original_video_shape = original_video.shape + latent = p_vae_encode(original_video) + + jitted_decode = jax.jit(functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False)) + video = jitted_decode(latent)[0] + video = jnp.transpose(video, (0, 4, 1, 2, 3)) + assert video.shape == original_video_shape + + original_video = torch.from_numpy(np.array(original_video.astype(jnp.float32))).to(dtype=torch.bfloat16) + video = torch.from_numpy(np.array(video)).to(dtype=torch.bfloat16) + video = video_processor.postprocess_video(video, output_type="np") + original_video = video_processor.postprocess_video(original_video, output_type="np") + ssim_compare = ssim(video[0], original_video[0], multichannel=True, channel_axis=-1, data_range=255) + assert ssim_compare >= 0.9999 if __name__ == "__main__": diff --git a/src/maxdiffusion/utils/export_utils.py b/src/maxdiffusion/utils/export_utils.py index 461922e94..5dfa3562f 100644 --- a/src/maxdiffusion/utils/export_utils.py +++ b/src/maxdiffusion/utils/export_utils.py @@ -110,30 +110,32 @@ def export_to_obj(mesh, output_obj_path: str = None): with open(output_obj_path, "w") as f: f.writelines("\n".join(combined_data)) + def _legacy_export_to_video( video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10 ): - if is_opencv_available(): - import cv2 - else: - raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) - if output_video_path is None: - output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + if is_opencv_available(): + import cv2 + else: + raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name - if isinstance(video_frames[0], np.ndarray): - video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] + if isinstance(video_frames[0], np.ndarray): + video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] - elif isinstance(video_frames[0], PIL.Image.Image): - video_frames = [np.array(frame) for frame in video_frames] + elif isinstance(video_frames[0], PIL.Image.Image): + video_frames = [np.array(frame) for frame in video_frames] - fourcc = cv2.VideoWriter_fourcc(*"mp4v") - h, w, c = video_frames[0].shape - video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h)) - for i in range(len(video_frames)): - img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) - video_writer.write(img) + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + h, w, c = video_frames[0].shape + video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h)) + for i in range(len(video_frames)): + img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + + return output_video_path - return output_video_path def export_to_video( video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], @@ -143,64 +145,64 @@ def export_to_video( bitrate: Optional[int] = None, macro_block_size: Optional[int] = 16, ) -> str: - """ - quality: - Video output quality. Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to - prevent variable bitrate flags to FFMPEG so you can manually specify them using output_params instead. - Specifying a fixed bitrate using `bitrate` disables this parameter. - - bitrate: - Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead. - Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter - rather than specifiying a fixed bitrate with this parameter. - - macro_block_size: - Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number - imageio will tell ffmpeg to scale the image up to the next closest size divisible by this number. Most codecs - are compatible with a macroblock size of 16 (default), some can go smaller (4, 8). To disable this automatic - feature set it to None or 1, however be warned many players can't decode videos that are odd in size and some - codecs will produce poor results or fail. See https://en.wikipedia.org/wiki/Macroblock. - """ - # TODO: Dhruv. Remove by Diffusers release 0.33.0 - # Added to prevent breaking existing code - if not is_imageio_available(): - logger.warning( - ( - "It is recommended to use `export_to_video` with `imageio` and `imageio-ffmpeg` as a backend. \n" - "These libraries are not present in your environment. Attempting to use legacy OpenCV backend to export video. \n" - "Support for the OpenCV backend will be deprecated in a future Diffusers version" - ) + """ + quality: + Video output quality. Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to + prevent variable bitrate flags to FFMPEG so you can manually specify them using output_params instead. + Specifying a fixed bitrate using `bitrate` disables this parameter. + + bitrate: + Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead. + Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter + rather than specifiying a fixed bitrate with this parameter. + + macro_block_size: + Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number + imageio will tell ffmpeg to scale the image up to the next closest size divisible by this number. Most codecs + are compatible with a macroblock size of 16 (default), some can go smaller (4, 8). To disable this automatic + feature set it to None or 1, however be warned many players can't decode videos that are odd in size and some + codecs will produce poor results or fail. See https://en.wikipedia.org/wiki/Macroblock. + """ + # TODO: Dhruv. Remove by Diffusers release 0.33.0 + # Added to prevent breaking existing code + if not is_imageio_available(): + logger.warning( + ( + "It is recommended to use `export_to_video` with `imageio` and `imageio-ffmpeg` as a backend. \n" + "These libraries are not present in your environment. Attempting to use legacy OpenCV backend to export video. \n" + "Support for the OpenCV backend will be deprecated in a future Diffusers version" ) - return _legacy_export_to_video(video_frames, output_video_path, fps) - - if is_imageio_available(): - import imageio - else: - raise ImportError(BACKENDS_MAPPING["imageio"][1].format("export_to_video")) - - try: - imageio.plugins.ffmpeg.get_exe() - except AttributeError: - raise AttributeError( - ( - "Found an existing imageio backend in your environment. Attempting to export video with imageio. \n" - "Unable to find a compatible ffmpeg installation in your environment to use with imageio. Please install via `pip install imageio-ffmpeg" - ) + ) + return _legacy_export_to_video(video_frames, output_video_path, fps) + + if is_imageio_available(): + import imageio + else: + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("export_to_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + ( + "Found an existing imageio backend in your environment. Attempting to export video with imageio. \n" + "Unable to find a compatible ffmpeg installation in your environment to use with imageio. Please install via `pip install imageio-ffmpeg" ) + ) - if output_video_path is None: - output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name - if isinstance(video_frames[0], np.ndarray): - video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] + if isinstance(video_frames[0], np.ndarray): + video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] - elif isinstance(video_frames[0], PIL.Image.Image): - video_frames = [np.array(frame) for frame in video_frames] + elif isinstance(video_frames[0], PIL.Image.Image): + video_frames = [np.array(frame) for frame in video_frames] - with imageio.get_writer( - output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size - ) as writer: - for frame in video_frames: - writer.append_data(frame) + with imageio.get_writer( + output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size + ) as writer: + for frame in video_frames: + writer.append_data(frame) - return output_video_path + return output_video_path diff --git a/src/maxdiffusion/utils/import_utils.py b/src/maxdiffusion/utils/import_utils.py index 224bad16c..d83596e8d 100644 --- a/src/maxdiffusion/utils/import_utils.py +++ b/src/maxdiffusion/utils/import_utils.py @@ -51,18 +51,20 @@ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + def _is_package_available(pkg_name: str): - pkg_exists = importlib.util.find_spec(pkg_name) is not None - pkg_version = "N/A" + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" + + if pkg_exists: + try: + pkg_version = importlib_metadata.version(pkg_name) + logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False - if pkg_exists: - try: - pkg_version = importlib_metadata.version(pkg_name) - logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") - except (ImportError, importlib_metadata.PackageNotFoundError): - pkg_exists = False + return pkg_exists, pkg_version - return pkg_exists, pkg_version _torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: @@ -298,9 +300,11 @@ def _is_package_available(pkg_name: str): except importlib_metadata.PackageNotFoundError: _peft_available = False + def is_imageio_available(): return _imageio_available + def is_torch_available(): return _torch_available @@ -727,4 +731,3 @@ def _get_module(self, module_name: str): def __reduce__(self): return (self.__class__, (self._name, self.__file__, self._import_structure)) - diff --git a/src/maxdiffusion/utils/loading_utils copy.py b/src/maxdiffusion/utils/loading_utils copy.py deleted file mode 100644 index fd66aaa4d..000000000 --- a/src/maxdiffusion/utils/loading_utils copy.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -import tempfile -from typing import Any, Callable, List, Optional, Tuple, Union -from urllib.parse import unquote, urlparse - -import PIL.Image -import PIL.ImageOps -import requests - -from .import_utils import BACKENDS_MAPPING, is_imageio_available - - -def load_image( - image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None -) -> PIL.Image.Image: - """ - Loads `image` to a PIL Image. - - Args: - image (`str` or `PIL.Image.Image`): - The image to convert to the PIL Image format. - convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*): - A conversion method to apply to the image after loading it. When set to `None` the image will be converted - "RGB". - - Returns: - `PIL.Image.Image`: - A PIL Image. - """ - if isinstance(image, str): - if image.startswith("http://") or image.startswith("https://"): - image = PIL.Image.open(requests.get(image, stream=True).raw) - elif os.path.isfile(image): - image = PIL.Image.open(image) - else: - raise ValueError( - f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path." - ) - elif isinstance(image, PIL.Image.Image): - image = image - else: - raise ValueError( - "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image." - ) - - image = PIL.ImageOps.exif_transpose(image) - - if convert_method is not None: - image = convert_method(image) - else: - image = image.convert("RGB") - - return image - - -def load_video( - video: str, - convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, -) -> List[PIL.Image.Image]: - """ - Loads `video` to a list of PIL Image. - - Args: - video (`str`): - A URL or Path to a video to convert to a list of PIL Image format. - convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): - A conversion method to apply to the video after loading it. When set to `None` the images will be converted - to "RGB". - - Returns: - `List[PIL.Image.Image]`: - The video as a list of PIL images. - """ - is_url = video.startswith("http://") or video.startswith("https://") - is_file = os.path.isfile(video) - was_tempfile_created = False - - if not (is_url or is_file): - raise ValueError( - f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." - ) - - if is_url: - response = requests.get(video, stream=True) - if response.status_code != 200: - raise ValueError(f"Failed to download video. Status code: {response.status_code}") - - parsed_url = urlparse(video) - file_name = os.path.basename(unquote(parsed_url.path)) - - suffix = os.path.splitext(file_name)[1] or ".mp4" - video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name - - was_tempfile_created = True - - video_data = response.iter_content(chunk_size=8192) - with open(video_path, "wb") as f: - for chunk in video_data: - f.write(chunk) - - video = video_path - - pil_images = [] - if video.endswith(".gif"): - gif = PIL.Image.open(video) - try: - while True: - pil_images.append(gif.copy()) - gif.seek(gif.tell() + 1) - except EOFError: - pass - - else: - if is_imageio_available(): - import imageio - else: - raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video")) - - try: - imageio.plugins.ffmpeg.get_exe() - except AttributeError: - raise AttributeError( - "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg" - ) - - with imageio.get_reader(video) as reader: - # Read all frames - for frame in reader: - pil_images.append(PIL.Image.fromarray(frame)) - - if was_tempfile_created: - os.remove(video_path) - - if convert_method is not None: - pil_images = convert_method(pil_images) - - return pil_images - - -# Taken from `transformers`. -def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]: - if "." in tensor_name: - splits = tensor_name.split(".") - for split in splits[:-1]: - new_module = getattr(module, split) - if new_module is None: - raise ValueError(f"{module} has no attribute {split}.") - module = new_module - tensor_name = splits[-1] - return module, tensor_name - - -def get_submodule_by_name(root_module, module_path: str): - current = root_module - parts = module_path.split(".") - for part in parts: - if part.isdigit(): - idx = int(part) - current = current[idx] # e.g., for nn.ModuleList or nn.Sequential - else: - current = getattr(current, part) - return current diff --git a/src/maxdiffusion/utils/loading_utils.py b/src/maxdiffusion/utils/loading_utils.py index 77caa39e0..4a480b3c7 100644 --- a/src/maxdiffusion/utils/loading_utils.py +++ b/src/maxdiffusion/utils/loading_utils.py @@ -37,85 +37,86 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: image = image.convert("RGB") return image + def load_video( video: str, convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, ) -> List[PIL.Image.Image]: - """ - Loads `video` to a list of PIL Image. - - Args: - video (`str`): - A URL or Path to a video to convert to a list of PIL Image format. - convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): - A conversion method to apply to the video after loading it. When set to `None` the images will be converted - to "RGB". - - Returns: - `List[PIL.Image.Image]`: - The video as a list of PIL images. - """ - is_url = video.startswith("http://") or video.startswith("https://") - is_file = os.path.isfile(video) - was_tempfile_created = False - - if not (is_url or is_file): - raise ValueError( - f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." - ) - - if is_url: - response = requests.get(video, stream=True) - if response.status_code != 200: - raise ValueError(f"Failed to download video. Status code: {response.status_code}") - - parsed_url = urlparse(video) - file_name = os.path.basename(unquote(parsed_url.path)) - - suffix = os.path.splitext(file_name)[1] or ".mp4" - video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name - - was_tempfile_created = True - - video_data = response.iter_content(chunk_size=8192) - with open(video_path, "wb") as f: - for chunk in video_data: - f.write(chunk) - - video = video_path - - pil_images = [] - if video.endswith(".gif"): - gif = PIL.Image.open(video) - try: - while True: - pil_images.append(gif.copy()) - gif.seek(gif.tell() + 1) - except EOFError: - pass + """ + Loads `video` to a list of PIL Image. + + Args: + video (`str`): + A URL or Path to a video to convert to a list of PIL Image format. + convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): + A conversion method to apply to the video after loading it. When set to `None` the images will be converted + to "RGB". + + Returns: + `List[PIL.Image.Image]`: + The video as a list of PIL images. + """ + is_url = video.startswith("http://") or video.startswith("https://") + is_file = os.path.isfile(video) + was_tempfile_created = False + + if not (is_url or is_file): + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." + ) + + if is_url: + response = requests.get(video, stream=True) + if response.status_code != 200: + raise ValueError(f"Failed to download video. Status code: {response.status_code}") + + parsed_url = urlparse(video) + file_name = os.path.basename(unquote(parsed_url.path)) + + suffix = os.path.splitext(file_name)[1] or ".mp4" + video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name + was_tempfile_created = True + + video_data = response.iter_content(chunk_size=8192) + with open(video_path, "wb") as f: + for chunk in video_data: + f.write(chunk) + + video = video_path + + pil_images = [] + if video.endswith(".gif"): + gif = PIL.Image.open(video) + try: + while True: + pil_images.append(gif.copy()) + gif.seek(gif.tell() + 1) + except EOFError: + pass + + else: + if is_imageio_available(): + import imageio else: - if is_imageio_available(): - import imageio - else: - raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video")) - - try: - imageio.plugins.ffmpeg.get_exe() - except AttributeError: - raise AttributeError( - "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg" - ) - - with imageio.get_reader(video) as reader: - # Read all frames - for frame in reader: - pil_images.append(PIL.Image.fromarray(frame)) - - if was_tempfile_created: - os.remove(video_path) - - if convert_method is not None: - pil_images = convert_method(pil_images) - - return pil_images \ No newline at end of file + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg" + ) + + with imageio.get_reader(video) as reader: + # Read all frames + for frame in reader: + pil_images.append(PIL.Image.fromarray(frame)) + + if was_tempfile_created: + os.remove(video_path) + + if convert_method is not None: + pil_images = convert_method(pil_images) + + return pil_images From bab4d1704944c23e451d8fdc0c78d9e4b5de0de6 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 8 May 2025 21:10:30 +0000 Subject: [PATCH 22/54] remove wan from readme. --- README.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/README.md b/README.md index 2dc6523c4..e14603ac4 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,6 @@ MaxDiffusion supports - [Training](#training) - [Dreambooth](#dreambooth) - [Inference](#inference) - - [Wan 2.1](#wan) - [Flux](#flux) - [Fused Attention for GPU:](#fused-attention-for-gpu) - [Hyper SDXL LoRA](#hyper-sdxl-lora) @@ -172,12 +171,6 @@ To generate images, run the following command: ```bash python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run" ``` - - ## Wan - - ```bash - python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_t2v.yml run_name="wan-test" output_dir="gs://jfacevedo-maxdiffusion" jax_cache_dir="/tmp/" - ``` ## Flux From 7a8daedd71051850a9dd9ec547521c1a44cc95fd Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 8 May 2025 21:13:23 +0000 Subject: [PATCH 23/54] remove unused files. --- src/maxdiffusion/pipelines/wan/__init__.py | 0 .../pipelines/wan/pipeline_wan.py | 84 ------------------- ...heduling_flow_match_euler_discrete_flax.py | 74 ---------------- 3 files changed, 158 deletions(-) delete mode 100644 src/maxdiffusion/pipelines/wan/__init__.py delete mode 100644 src/maxdiffusion/pipelines/wan/pipeline_wan.py delete mode 100644 src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py diff --git a/src/maxdiffusion/pipelines/wan/__init__.py b/src/maxdiffusion/pipelines/wan/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/maxdiffusion/pipelines/wan/pipeline_wan.py b/src/maxdiffusion/pipelines/wan/pipeline_wan.py deleted file mode 100644 index 27b5844a8..000000000 --- a/src/maxdiffusion/pipelines/wan/pipeline_wan.py +++ /dev/null @@ -1,84 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -from typing import Union, List -from transformers import AutoTokenizer, UMT5EncoderModel -import torch -from ...models.wan.transformers.transformer_flux_wan_nnx import WanModel -from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan -from ..pipeline_flax_utils import FlaxDiffusionPipeline -from ...video_processor import VideoProcessor -# from ...schedulers import FlowMatchEulerDiscreteScheduler - - -class WanPipeline(FlaxDiffusionPipeline): - - def __init__( - self, - tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, - transformer: WanModel, - vae: AutoencoderKLWan, - # scheduler: FlowMatchEulerDiscreteScheduler, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - transformer=transformer, - # scheduler=scheduler - ) - - self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 - self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - - def _get_t5_prompt_embds( - self, - prompt: Union[str, List[str]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - ): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask - seq_lens = mask.gt(0).sum(dim=1).long() - - prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state - # prompt_embeds = prompt_embeds.to(dtype=dtype) - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - return prompt_embeds diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py deleted file mode 100644 index ec29a353e..000000000 --- a/src/maxdiffusion/schedulers/scheduling_flow_match_euler_discrete_flax.py +++ /dev/null @@ -1,74 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import flax -import jax.numpy as jnp - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import ( - CommonSchedulerState, - # FlaxKarrasDiffusionSchedulers, - FlaxSchedulerMixin, - FlaxSchedulerOutput, - broadcast_to_shape_from_left, -) - - -@flax.struct.dataclass -class FlowMatchEulerDiscreteSchedulerState: - common: CommonSchedulerState - - -@dataclass -class FlowMatchEulerDiscreteSchedulerOutput(FlaxSchedulerOutput): - state: FlowMatchEulerDiscreteSchedulerState - - -class FlowMatchEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): - # _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] - - dtype: jnp.dtype - - @property - def has_state(self): - return True - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - shift: float = 1.0, - use_dynamic_shifting: bool = False, - base_shift: Optional[float] = 0.5, - max_shift: Optional[float] = 1.15, - base_image_seq_len: Optional[int] = 256, - max_image_seq_len: Optional[int] = 4096, - invert_sigmas: bool = False, - shift_terminal: Optional[float] = None, - use_karras_sigmas: Optional[bool] = False, - use_exponential_sigmas: Optional[bool] = False, - use_beta_sigmas: Optional[bool] = False, - time_shift_type: str = "exponential", - dtype: jnp.dtype = jnp.float32, - ): - self.dtype = dtype - - def create_state(self, common: Optional[CommonSchedulerState] = None) -> FlowMatchEulerDiscreteSchedulerState: - if common is None: - common = CommonSchedulerState.create(self) From b31b4ad9b105004a9388c5fcc32ab26bef76b24b Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 8 May 2025 21:36:41 +0000 Subject: [PATCH 24/54] more linter fixes. --- src/maxdiffusion/models/flux/util.py | 1 - src/maxdiffusion/models/wan/wan_utils.py | 4 ++-- src/maxdiffusion/tests/wan_vae_test.py | 9 +++------ src/maxdiffusion/utils/loading_utils.py | 2 +- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index 504b71e9f..d856fda5d 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -4,7 +4,6 @@ import jax from jax.typing import DTypeLike -import torch # need for torch 2 jax from chex import Array from flax.traverse_util import flatten_dict, unflatten_dict from huggingface_hub import hf_hub_download diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index b39388089..1a9948fdb 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -3,7 +3,7 @@ from maxdiffusion import max_logging from huggingface_hub import hf_hub_download from safetensors import safe_open -from flax.traverse_util import flatten_dict, unflatten_dict +from flax.traverse_util import unflatten_dict from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict) @@ -12,7 +12,7 @@ def _tuple_str_to_int(in_tuple): for item in in_tuple: try: out_list.append(int(item)) - except: + except ValueError: out_list.append(item) return tuple(out_list) diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 5e37506d0..7d750c8bb 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -23,7 +23,6 @@ from flax import nnx import numpy as np import unittest -import pytest from absl.testing import absltest from skimage.metrics import structural_similarity as ssim from ..models.wan.autoencoder_kl_wan import ( @@ -172,7 +171,7 @@ def test_wanrms_norm(self): dummy_input = jnp.ones(input_shape) output = wanrms_norm(dummy_input) output_np = np.array(output) - assert np.allclose(output_np, torch_output_np) == True + assert np.allclose(output_np, torch_output_np) is True # --- Test Case 2: images == False --- model = TorchWanRMS_norm(dim, images=False) @@ -186,7 +185,7 @@ def test_wanrms_norm(self): dummy_input = jnp.ones(input_shape) output = wanrms_norm(dummy_input) output_np = np.array(output) - assert np.allclose(output_np, torch_output_np) == True + assert np.allclose(output_np, torch_output_np) is True def test_zero_padded_conv(self): @@ -235,8 +234,6 @@ def test_wan_resample(self): w = 720 mode = "downsample2d" input_shape = (batch, dim, t, h, w) - expected_output_shape = (1, dim, 1, 240, 360) - # output dim should be (1, 96, 1, 480, 720) dummy_input = torch.ones(input_shape) torch_wan_resample = TorchWanResample(dim=dim, mode=mode) torch_output = torch_wan_resample(dummy_input) @@ -426,7 +423,7 @@ def vae_encode(video, wan_vae, vae_cache, key): rngs = nnx.Rngs(key) wan_vae = AutoencoderKLWan.from_config(pretrained_model_name_or_path, subfolder="vae", rngs=rngs) vae_cache = AutoencoderKLWanCache(wan_vae) - video_path, fps = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4", 8 + video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" video = load_video(video_path) vae_scale_factor_spatial = 2 ** len(wan_vae.temperal_downsample) diff --git a/src/maxdiffusion/utils/loading_utils.py b/src/maxdiffusion/utils/loading_utils.py index 4a480b3c7..6107272c7 100644 --- a/src/maxdiffusion/utils/loading_utils.py +++ b/src/maxdiffusion/utils/loading_utils.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union import PIL.Image import PIL.ImageOps From d449d1f079856f3504c94d15f9dbedb45375741e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 9 May 2025 15:39:48 +0000 Subject: [PATCH 25/54] add WanRotaryPosEmbed --- src/maxdiffusion/models/embeddings_flax.py | 20 ++- .../models/wan/transformers/__init__.py | 15 ++ .../wan/transformers/transformer_wan.py | 164 ++++++++++++++++++ .../tests/wan_transformer_test.py | 47 +++++ 4 files changed, 241 insertions(+), 5 deletions(-) create mode 100644 src/maxdiffusion/models/wan/transformers/__init__.py create mode 100644 src/maxdiffusion/models/wan/transformers/transformer_wan.py create mode 100644 src/maxdiffusion/tests/wan_transformer_test.py diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index cc961e131..34f8b3dd2 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -102,7 +102,13 @@ def __call__(self, timesteps): def get_1d_rotary_pos_embed( - dim: int, pos: Union[jnp.array, int], theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0, freqs_dtype=jnp.float32 + dim: int, + pos: Union[jnp.array, int], + theta: float = 10000.0, + linear_factor=1.0, + ntk_factor=1.0, + freqs_dtype=jnp.float32, + use_real: bool = True ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -115,10 +121,14 @@ def get_1d_rotary_pos_embed( theta = theta * ntk_factor freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor freqs = jnp.outer(pos, freqs) - freqs_cos = jnp.cos(freqs) - freqs_sin = jnp.sin(freqs) - out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1) - + if use_real: + # Flux + freqs_cos = jnp.cos(freqs) + freqs_sin = jnp.sin(freqs) + out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1) + else: + # Wan 2.1 + out = jax.lax.complex(jnp.ones_like(freqs), freqs) return out diff --git a/src/maxdiffusion/models/wan/transformers/__init__.py b/src/maxdiffusion/models/wan/transformers/__init__.py new file mode 100644 index 000000000..522c1e64b --- /dev/null +++ b/src/maxdiffusion/models/wan/transformers/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" \ No newline at end of file diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py new file mode 100644 index 000000000..60f956845 --- /dev/null +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -0,0 +1,164 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Tuple, Optional +import jax +import jax.numpy as jnp +from flax import nnx +from .... import common_types, max_logging +from ...modeling_flax_utils import FlaxModelMixin +from ....configuration_utils import ConfigMixin +from ...embeddings_flax import get_1d_rotary_pos_embed + +BlockSizes = common_types.BlockSizes + +class WanRotaryPosEmbed(nnx.Module): + def __init__( + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0 + ): + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + + freqs = [] + for dim in [t_dim, h_dim, w_dim]: + freq = get_1d_rotary_pos_embed( + dim, + self.max_seq_len, + theta, + freqs_dtype=jnp.float64, + use_real=False + ) + freqs.append(freq) + self.freqs = jnp.concatenate(freqs, axis=1) + + def __call__(self, hidden_states: jax.Array) -> jax.Array: + _, num_frames, height, width, _ = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + sizes = [ + self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), + self.attention_head_dim // 6, + self.attention_head_dim // 6, + ] + cumulative_sizes = jnp.cumsum(jnp.array(sizes)) + split_indices = cumulative_sizes[:-1] + freqs_split = jnp.split(self.freqs, split_indices, axis=1) + + freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1) + freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1])) + + freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2) + freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1])) + + freqs_w = jnp.expand_dims(jnp.expand_dims(freqs_split[2][:ppw], axis=0), axis=1) + freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1])) + + freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1) + freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1)) + return freqs_final + + +class WanTransformer3DModel(nnx.Module, FlaxModelMixin, ConfigMixin): + def __init__( + self, + rngs: nnx.Rngs, + patch_size: Tuple[int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: Optional[int] = None, + added_kv_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: Optional[int] = None, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", + ): + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + #1. Patch & position embedding + self.rope = WanRotaryPosEmbed( + attention_head_dim, + patch_size, + rope_max_seq_len + ) + + +class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): + + def __init__( + self, + rngs: nnx.Rngs, + model_type="t2v", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", + ): + self.path_embedding = nnx.Conv( + in_dim, + dim, + kernel_size=patch_size, + strides=patch_size, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("batch",)), + rngs=rngs, + ) + + def __call__(self, x): + x = self.path_embedding(x) + return x \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py new file mode 100644 index 000000000..f4003c7bd --- /dev/null +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -0,0 +1,47 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import jax +import jax.numpy as jnp +import unittest +from absl.testing import absltest +from flax import nnx + +from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed + +class WanTransformerTest(unittest.TestCase): + def setUp(self): + WanTransformerTest.dummy_data = {} + + def test_rotary_pos_embed(self): + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_states_shape = (batch_size, frames, height, width, channels) + dummy_hidden_states = jnp.ones(hidden_states_shape) + wan_rot_embed = WanRotaryPosEmbed( + attention_head_dim=128, + patch_size=[1, 2, 2], + max_seq_len=1024 + ) + dummy_output = wan_rot_embed(dummy_hidden_states) + assert dummy_output.shape == (1, 1, 75600, 64) + # output shape should be torch.Size([1, 1, 75600, 64]) + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From 064fc5fe9c05a75aacb785aebc262f6379f159f3 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 9 May 2025 19:58:22 +0000 Subject: [PATCH 26/54] add nnx classes for timestep embeddings and timesteps. --- src/maxdiffusion/models/embeddings_flax.py | 101 ++++++++++++- .../models/modeling_flax_utils.py | 7 + .../models/wan/autoencoder_kl_wan.py | 11 +- .../wan/transformers/transformer_wan.py | 133 ++++++++++++++---- .../tests/wan_transformer_test.py | 15 +- 5 files changed, 229 insertions(+), 38 deletions(-) diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index 34f8b3dd2..3744bac79 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import math - +from typing import Optional import flax.linen as nn +from flax import nnx import jax.numpy as jnp from typing import List, Union import jax +from .modeling_flax_utils import get_activation def get_sinusoidal_embeddings( @@ -56,6 +58,86 @@ def get_sinusoidal_embeddings( signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) return signal +class NNXTimestepEmbedding(nnx.Module): + r""" + Time step Embedding Module. Learns embeddings for input time steps. + + Args: + time_embed_dim (`int`, *optional*, defaults to `32`): + Time step embedding dimension + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + def __init__( + self, + rngs: nnx.Rngs, + in_channels: int, + time_embed_dim: int = 32, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim: int = None, + sample_proj_bias=True, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): + self.linear_1 = nnx.Linear( + rngs=rngs, + in_features=in_channels, + out_features=time_embed_dim, + use_bias=sample_proj_bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + ) + + if cond_proj_dim is not None: + self.cond_proj = nnx.Linear( + rngs=rngs, + ) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + + self.linear_2 = nnx.Linear( + rngs=rngs, + in_features=time_embed_dim, + out_features=time_embed_dim_out, + use_bias=sample_proj_bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + ) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def __call__(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + class FlaxTimestepEmbedding(nn.Module): r""" @@ -79,6 +161,23 @@ def __call__(self, temb): temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, param_dtype=self.weights_dtype, name="linear_2")(temb) return temb +class NNXFlaxTimesteps(nnx.Module): + def __init__( + self, + dim: int = 32, + flip_sin_to_cos: bool = False, + freq_shift: float = 1.0, + scale: int = 1, + ): + self.dim = dim + self.flip_sin_to_cos = flip_sin_to_cos + self.freq_shift = freq_shift + self.scale = scale + + def __call__(self, timesteps): + return get_sinusoidal_embeddings( + timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift + ) class FlaxTimesteps(nn.Module): r""" diff --git a/src/maxdiffusion/models/modeling_flax_utils.py b/src/maxdiffusion/models/modeling_flax_utils.py index b93ba8396..f22f8d925 100644 --- a/src/maxdiffusion/models/modeling_flax_utils.py +++ b/src/maxdiffusion/models/modeling_flax_utils.py @@ -43,6 +43,13 @@ logger = logging.get_logger(__name__) +_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "mish": jax.nn.mish} + +def get_activation(name: str): + func = _ACTIVATIONS.get(name) + if func is None: + raise ValueError(f"Unknown activation function: {name}") + return func class FlaxModelMixin(PushToHubMixin): r""" diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 8325c3707..9c92e2ee2 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -20,7 +20,7 @@ import jax.numpy as jnp from flax import nnx from ...configuration_utils import ConfigMixin -from ..modeling_flax_utils import FlaxModelMixin +from ..modeling_flax_utils import FlaxModelMixin, get_activation from ... import common_types from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) @@ -28,15 +28,6 @@ CACHE_T = 2 -_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "mish": jax.nn.mish} - - -def get_activation(name: str): - func = _ACTIVATIONS.get(name) - if func is None: - raise ValueError(f"Unknown activation function: {name}") - return func - # Helper to ensure kernel_size, stride, padding are tuples of 3 integers def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 60f956845..4751e5120 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -14,14 +14,14 @@ limitations under the License. """ -from typing import Tuple, Optional +from typing import Tuple, Optional, Dict, Union, Any import jax import jax.numpy as jnp from flax import nnx from .... import common_types, max_logging from ...modeling_flax_utils import FlaxModelMixin -from ....configuration_utils import ConfigMixin -from ...embeddings_flax import get_1d_rotary_pos_embed +from ....configuration_utils import ConfigMixin, register_to_config +from ...embeddings_flax import get_1d_rotary_pos_embed, NNXFlaxTimesteps, NNXTimestepEmbedding BlockSizes = common_types.BlockSizes @@ -65,7 +65,7 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array: cumulative_sizes = jnp.cumsum(jnp.array(sizes)) split_indices = cumulative_sizes[:-1] freqs_split = jnp.split(self.freqs, split_indices, axis=1) - + freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1) freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1])) @@ -80,6 +80,40 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array: return freqs_final +class WanTimeTextImageEmbedding(nnx.Module): + def __init__( + self, + rngs: nnx.Rngs, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: Optional[int] = None, + pos_embed_seq_len: Optional[int] = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): + self.timesteps_proj = NNXFlaxTimesteps( + dim=time_freq_dim, flip_sin_to_cos=True, freq_shift=0 + ) + self.time_embedder = NNXTimestepEmbedding( + rngs=rngs, in_channels=time_freq_dim, time_embed_dim=dim, + dtype=dtype, weights_dtype=weights_dtype, precision=precision + ) + + def __call__( + self, + timestep: jax.Array, + encoder_hidden_states: jax.Array, + encoder_hidden_states_image: Optional[jax.Array] = None + ): + timestep = self.timesteps_proj(timestep) + temb = self.time_embedder(timestep) + breakpoint() + + + class WanTransformer3DModel(nnx.Module, FlaxModelMixin, ConfigMixin): def __init__( self, @@ -120,25 +154,28 @@ def __init__( class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): - + + @register_to_config def __init__( self, rngs: nnx.Rngs, model_type="t2v", - patch_size=(1, 2, 2), - text_len=512, - in_dim=16, - dim=2048, - ffn_dim=8192, - freq_dim=256, - text_dim=4096, - out_dim=16, - num_heads=16, - num_layers=32, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=True, - eps=1e-6, + patch_size: Tuple[int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: Optional[int] = None, + added_kn_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: Optional[int] = None, flash_min_seq_length: int = 4096, flash_block_sizes: BlockSizes = None, mesh: jax.sharding.Mesh = None, @@ -147,18 +184,62 @@ def __init__( precision: jax.lax.Precision = None, attention: str = "dot_product", ): - self.path_embedding = nnx.Conv( - in_dim, - dim, + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + #1. Patch & position embedding + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nnx.Conv( + in_channels, + inner_dim, + rngs=rngs, kernel_size=patch_size, strides=patch_size, dtype=dtype, param_dtype=weights_dtype, precision=precision, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("batch",)), - rngs=rngs, ) - def __call__(self, x): - x = self.path_embedding(x) - return x \ No newline at end of file + # 2. Condition embeddings + # image_embedding_dim=1280 for I2V model + self.condition_embedder = WanTimeTextImageEmbedding( + rngs=rngs, + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len + ) + + def __call__( + self, + hidden_states: jax.Array, + timestep: jax.Array, + encoder_hidden_states: jax.Array, + encoder_hidden_states_image: Optional[jax.Array] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[jax.Array, Dict[str, jax.Array]]: + batch_size, num_frames, height, width, num_channels = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + + rotary_emb = self.rope(hidden_states) + hidden_states = self.patch_embedding(hidden_states) + hidden_states = jax.lax.collapse(hidden_states, 1, -1) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image + ) + #hidden_states = + # Torch shape: ([1, 5120, 21, 45, 80]) + # Jax shape: (1, 21, 45, 80, 5120) so channels is 5120 + + + return hidden_states \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index f4003c7bd..f16043631 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -21,6 +21,7 @@ from flax import nnx from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed +from ..models.embeddings_flax import NNXTimestepEmbedding class WanTransformerTest(unittest.TestCase): def setUp(self): @@ -41,7 +42,19 @@ def test_rotary_pos_embed(self): ) dummy_output = wan_rot_embed(dummy_hidden_states) assert dummy_output.shape == (1, 1, 75600, 64) - # output shape should be torch.Size([1, 1, 75600, 64]) + + def test_nnx_timestep_embedding(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + + dummy_sample = jnp.ones((1, 256)) + layer = NNXTimestepEmbedding( + rngs=rngs, + in_channels=256, + time_embed_dim=5120 + ) + dummy_output = layer(dummy_sample) + assert dummy_output.shape == (1, 5120) if __name__ == "__main__": absltest.main() \ No newline at end of file From 08444fd3aba4cec623ac4c16c24727b0db96bf8c Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 9 May 2025 22:52:27 +0000 Subject: [PATCH 27/54] add wan time text embedding layer. --- src/maxdiffusion/models/embeddings_flax.py | 45 +++++++++++++++++++ .../models/modeling_flax_utils.py | 4 +- .../wan/transformers/transformer_wan.py | 34 ++++++++++++-- .../tests/wan_transformer_test.py | 43 ++++++++++++++++-- 4 files changed, 118 insertions(+), 8 deletions(-) diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index 3744bac79..38a633e29 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -230,6 +230,51 @@ def get_1d_rotary_pos_embed( out = jax.lax.complex(jnp.ones_like(freqs), freqs) return out +class NNXPixArtAlphaTextProjection(nnx.Module): + def __init__( + self, + rngs: nnx.Rngs, + in_features: int, + hidden_size: int, + out_features: int = None, + act_fn: str = "gelu_tanh", + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None + ): + if out_features is None: + out_features = hidden_size + + self.linear_1 = nnx.Linear( + rngs=rngs, + in_features=in_features, + out_features=hidden_size, + use_bias=True, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + ) + self.act_1 = get_activation(act_fn) + + self.linear_2 = nnx.Linear( + rngs=rngs, + in_features=hidden_size, + out_features=out_features, + use_bias=True, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + ) + + def __call__(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states class PixArtAlphaTextProjection(nn.Module): """ diff --git a/src/maxdiffusion/models/modeling_flax_utils.py b/src/maxdiffusion/models/modeling_flax_utils.py index f22f8d925..2e300f70e 100644 --- a/src/maxdiffusion/models/modeling_flax_utils.py +++ b/src/maxdiffusion/models/modeling_flax_utils.py @@ -42,8 +42,8 @@ logger = logging.get_logger(__name__) - -_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "mish": jax.nn.mish} +# gelu and gelu_tanh both use approximate=True by default +_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "gelu_tanh" : jax.nn.gelu, "mish": jax.nn.mish} def get_activation(name: str): func = _ACTIVATIONS.get(name) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 4751e5120..861a8366d 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -19,9 +19,14 @@ import jax.numpy as jnp from flax import nnx from .... import common_types, max_logging -from ...modeling_flax_utils import FlaxModelMixin +from ...modeling_flax_utils import FlaxModelMixin, get_activation from ....configuration_utils import ConfigMixin, register_to_config -from ...embeddings_flax import get_1d_rotary_pos_embed, NNXFlaxTimesteps, NNXTimestepEmbedding +from ...embeddings_flax import ( + get_1d_rotary_pos_embed, + NNXFlaxTimesteps, + NNXTimestepEmbedding, + NNXPixArtAlphaTextProjection +) BlockSizes = common_types.BlockSizes @@ -101,6 +106,23 @@ def __init__( rngs=rngs, in_channels=time_freq_dim, time_embed_dim=dim, dtype=dtype, weights_dtype=weights_dtype, precision=precision ) + self.act_fn = get_activation("silu") + self.time_proj = nnx.Linear( + rngs=rngs, + in_features=dim, + out_features=time_proj_dim, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + ) + self.text_embedder = NNXPixArtAlphaTextProjection( + rngs=rngs, + in_features=text_embed_dim, + hidden_size=dim, + act_fn="gelu_tanh", + ) def __call__( self, @@ -110,7 +132,13 @@ def __call__( ): timestep = self.timesteps_proj(timestep) temb = self.time_embedder(timestep) - breakpoint() + + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + raise NotImplementedError("currently img2vid is not supported") + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index f16043631..a6d4ed7db 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -20,8 +20,8 @@ from absl.testing import absltest from flax import nnx -from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed -from ..models.embeddings_flax import NNXTimestepEmbedding +from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed, WanTimeTextImageEmbedding +from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection class WanTransformerTest(unittest.TestCase): def setUp(self): @@ -42,7 +42,19 @@ def test_rotary_pos_embed(self): ) dummy_output = wan_rot_embed(dummy_hidden_states) assert dummy_output.shape == (1, 1, 75600, 64) - + + def test_nnx_pixart_alpha_text_projection(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + dummy_caption = jnp.ones((1, 512, 4096)) + layer = NNXPixArtAlphaTextProjection( + rngs=rngs, + in_features=4096, + hidden_size=5120 + ) + dummy_output = layer(dummy_caption) + dummy_output.shape == (1, 512, 5120) + def test_nnx_timestep_embedding(self): key = jax.random.key(0) rngs = nnx.Rngs(key) @@ -56,5 +68,30 @@ def test_nnx_timestep_embedding(self): dummy_output = layer(dummy_sample) assert dummy_output.shape == (1, 5120) + def test_wan_time_text_embedding(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + batch_size = 1 + dim=5120 + time_freq_dim=256 + time_proj_dim=30720 + text_embed_dim=4096 + layer = WanTimeTextImageEmbedding( + rngs=rngs, + dim=dim, + time_freq_dim=time_freq_dim, + time_proj_dim=time_proj_dim, + text_embed_dim=text_embed_dim + ) + + dummy_timestep = jnp.ones(batch_size) + + encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim) + dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape) + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(dummy_timestep, dummy_encoder_hidden_states) + assert temb.shape == (batch_size, dim) + assert timestep_proj.shape == (batch_size, time_proj_dim) + assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim) + if __name__ == "__main__": absltest.main() \ No newline at end of file From 2499b2defe3ada68dd876269a89a231ef7c0f956 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 9 May 2025 23:26:58 +0000 Subject: [PATCH 28/54] add fp32 layer norm --- src/maxdiffusion/models/normalization_flax.py | 18 ++++++++++++++++++ src/maxdiffusion/tests/wan_transformer_test.py | 16 ++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/src/maxdiffusion/models/normalization_flax.py b/src/maxdiffusion/models/normalization_flax.py index ea3b970d8..8c8463e62 100644 --- a/src/maxdiffusion/models/normalization_flax.py +++ b/src/maxdiffusion/models/normalization_flax.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp import flax.linen as nn +from flax import nnx class AdaLayerNormContinuous(nn.Module): @@ -147,3 +148,20 @@ def __call__(self, x, emb): else: raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided. Supported ones are: 'layer_norm'.") return x, gate_msa + +class FP32LayerNorm(nnx.Module): + def __init__(self, rngs: nnx.Rngs, dim: int, eps : float, elementwise_affine: bool): + self.layer_norm = nnx.LayerNorm( + rngs=rngs, + num_features=dim, + epsilon=eps, + use_bias=elementwise_affine, + use_scale=elementwise_affine, + param_dtype=jnp.float32, + dtype=jnp.float32 + ) + def __call__(self, inputs: jax.Array) -> jax.Array: + origin_dtype = inputs.dtype + return self.layer_norm( + inputs.astype(dtype=jnp.float32) + ).astype(dtype=origin_dtype) diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index a6d4ed7db..98892400a 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -22,6 +22,7 @@ from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed, WanTimeTextImageEmbedding from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection +from ..models.normalization_flax import FP32LayerNorm class WanTransformerTest(unittest.TestCase): def setUp(self): @@ -68,6 +69,21 @@ def test_nnx_timestep_embedding(self): dummy_output = layer(dummy_sample) assert dummy_output.shape == (1, 5120) + def test_fp32_layer_norm(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + batch_size = 1 + dummy_hidden_states = jnp.ones((batch_size, 75600, 5120)) + # expected same output shape with same dtype + layer = FP32LayerNorm( + rngs=rngs, + dim=5120, + eps=1e-6, + elementwise_affine=False + ) + dummy_output = layer(dummy_hidden_states) + assert dummy_output.shape == dummy_hidden_states.shape + def test_wan_time_text_embedding(self): key = jax.random.key(0) rngs = nnx.Rngs(key) From b9b246525d95bdf9b221373f3e2f9e19ea74752a Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 13 May 2025 00:41:52 +0000 Subject: [PATCH 29/54] wip - attention for wan. --- src/maxdiffusion/configs/base_wan_14b.yml | 258 +++++ src/maxdiffusion/models/attention_flax.py | 920 +++++++++++------- .../wan/transformers/transformer_wan.py | 41 +- 3 files changed, 867 insertions(+), 352 deletions(-) create mode 100644 src/maxdiffusion/configs/base_wan_14b.yml diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml new file mode 100644 index 000000000..d7a802225 --- /dev/null +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -0,0 +1,258 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True + +timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written. +write_timing_metrics: True + +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: '' + +unet_checkpoint: '' +revision: '' +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + +# Set true to load weights from pytorch +from_pt: True +split_head_dim: True +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te + +flash_block_sizes: { + "block_q" : 512, + "block_kv_compute" : 512, + "block_kv" : 512, + "block_q_dkv" : 512, + "block_kv_dkv" : 512, + "block_kv_dkv_compute" : 512, + "block_q_dq" : 512, + "block_kv_dq" : 512 +} +# GroupNorm groups +norm_num_groups: 32 + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +diffusion_scheduler_config: { + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', + rescale_zero_terminal_snr: False, + timestep_spacing: 'trailing' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' + +# Parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], + ['activation_kv', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: -1 +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-5 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 1500 +num_train_epochs: 1 +seed: 0 +output_dir: 'sdxl-model-finetuned' +per_device_batch_size: 1 + +warmup_steps_fraction: 0.1 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 0 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Generation parameters +prompt: "A magical castle in the middle of a forest, artistic drawing" +prompt_2: "A magical castle in the middle of a forest, artistic drawing" +negative_prompt: "purple, red" +do_classifier_free_guidance: True +guidance_scale: 3.5 +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 50 +save_final_checkpoint: False + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. + diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 2f8946056..bcd3c723c 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -14,8 +14,9 @@ import functools import math -from typing import Optional +from typing import Optional, Callable, Tuple import flax.linen as nn +from flax import nnx import jax import jax.numpy as jnp from jax.experimental import shard_map @@ -42,275 +43,295 @@ Quant = quantizations.AqtQuantization -Quant = quantizations.AqtQuantization - - def _maybe_aqt_einsum(quant: Quant): return jnp.einsum if quant is None else quant.einsum() +def _check_attention_inputs(query: Array, key: Array, value: Array) -> None: + """Check attention inputs.""" + + assert key.ndim == value.ndim, "k, v must have same rank." + assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match." + assert key.shape[-2] == value.shape[-2], "k, v num_kv_heads must match." + assert key.shape[-3] == value.shape[-3], "k, v lengths must match." + assert query.shape[-1] == key.shape[-1], "q, k depths must match." + +def _reshape_data_from_cudnn_flash(tensor): + # reshapes from [b, s, h, d] back to [b, s, h * d] + return tensor.reshape(tensor.shape[0], tensor.shape[1], -1) + +def _reshape_data_for_cudnn_flash(tensor, heads): + # reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format) + batch, seq, heads_and_dim_head = tensor.shape + tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads) + return tensor + +def _reshape_batch_dim_to_heads(tensor, heads): + batch_size, seq_len, dim = tensor.shape + head_size = heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + +def _reshape_heads_to_batch_dim(tensor, heads): + batch_size, seq_len, dim = tensor.shape + head_size = heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + +def _reshape_heads_to_head_dim(tensor): + # takes a tensor of shape [b, h, s, d] and reshapes to [b, s, h * d] + # This is used to transform the output of flash attention back into the format of other attention outputs + b, h, s, d = tensor.shape + tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3]) + return jnp.reshape(tensor, (b, -1, h * d)) + +def _unflatten_heads(tensor, heads): + # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format) + batch, seq, heads_and_dim_head = tensor.shape + tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads) + # Transpose to ('batch', 'heads', 'length', 'kv') + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + return tensor + +def _reshape_data_for_flash(tensor, heads, flash_block_size): + """ + Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. + """ + tensor = _unflatten_heads(tensor, heads) + kv_size = tensor.shape[-1] + if kv_size < 128: + npad = ((0, 0), (0, 0), (0, 0), (0, 128 - kv_size)) + tensor = jnp.pad(tensor, npad) + seq_len = tensor.shape[2] + rem = seq_len % flash_block_size + if rem != 0: + mul = seq_len // flash_block_size + npad = ((0, 0), (0, 0), (0, (mul + 1)*flash_block_size - seq_len), (0, 0)) + tensor = jnp.pad(tensor, npad) + return tensor, kv_size, seq_len + +def _tpu_flash_attention( + query: jax.Array, + key: jax.Array, + value: jax.Array, + heads: int, + mesh: Mesh, + flash_axis_names: AxisNames, + flash_block_sizes: BlockSizes) -> jax.Array: + """TPU Flash Attention""" + + if flash_block_sizes: + block_sizes = flash_block_sizes + else: + block_sizes = splash_attention_kernel.BlockSizes( + block_q=min(512, query.shape[2]), + block_kv_compute=min(512, key.shape[2]), + block_kv=min(512, key.shape[2]), + block_q_dkv=min(512, query.shape[2]), + block_kv_dkv=min(512, key.shape[2]), + block_kv_dkv_compute=min(512, query.shape[2]), + block_q_dq=min(512, query.shape[2]), + block_kv_dq=min(512, query.shape[2]), + ) -class AttentionOp(nn.Module): - mesh: Mesh - attention_kernel: str - scale: int - heads: int - dim_head: int - use_memory_efficient_attention: bool = False - split_head_dim: bool = False - float32_qk_product: bool = True - flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) - flash_min_seq_length: int = 4096 - flash_block_sizes: BlockSizes = None - dtype: DType = jnp.float32 - quant: Quant = None - - def setup(self): - if self.attention_kernel == "cudnn_flash_te": - from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error - - self.dpa_layer = DotProductAttention( - head_dim=self.dim_head, - num_attention_heads=self.heads, - num_gqa_groups=self.heads, - attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal' - attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' - # attention_dropout=self.dropout_rate, - dropout_rng_name="aqt", - dtype=self.dtype, - # float32_logits=self.float32_logits, - qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' - scale_factor=self.scale, - transpose_batch_sequence=False, - ) - - def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None: - """Check attention inputs.""" - - assert key.ndim == value.ndim, "k, v must have same rank." - assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match." - assert key.shape[-2] == value.shape[-2], "k, v num_kv_heads must match." - assert key.shape[-3] == value.shape[-3], "k, v lengths must match." - assert query.shape[-1] == key.shape[-1], "q, k depths must match." - - def apply_attention(self, query: Array, key: Array, value: Array): - """Routes to different attention kernels.""" - self.check_attention_inputs(query, key, value) - - if self.attention_kernel == "flash": - can_use_flash_attention = ( - query.shape[1] >= self.flash_min_seq_length - and key.shape[1] >= self.flash_min_seq_length - and value.shape[1] >= self.flash_min_seq_length - ) - else: - can_use_flash_attention = True - - if self.attention_kernel == "dot_product" or self.use_memory_efficient_attention or not can_use_flash_attention: - return self.apply_attention_dot(query, key, value) - elif self.attention_kernel == "flash": - return self.tpu_flash_attention(query, key * self.scale, value) - elif self.attention_kernel == "cudnn_flash_te": - return self.cudnn_flash_attention(query, key, value) - else: - raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.") - - def tpu_flash_attention(self, query: jax.Array, key: jax.Array, value: jax.Array) -> jax.Array: - """TPU Flash Attention""" - - query, kv_size = self.reshape_data_for_flash(query) - key, _ = self.reshape_data_for_flash(key) - value, _ = self.reshape_data_for_flash(value) - - axis_names = nn.logical_to_mesh_axes(self.flash_axis_names) - - @functools.partial( - shard_map.shard_map, - mesh=self.mesh, - in_specs=( - axis_names, - axis_names, - axis_names, - ), - out_specs=axis_names, - check_rep=False, + query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q) + key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute) + value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute) + # query_seq_len = query.shape[2] + # query_rem = query_seq_len % block_sizes.block_q + # if query_rem != 0: + # query_mul = query_seq_len // block_sizes.block_q + # npad = ((0, 0), (0, 0), (0, (query_mul + 1)*block_sizes.block_q - query.shape[2]), (0, 0)) + # query = jnp.pad(query, npad) + # key = jnp.pad(key, npad) + # value = jnp.pad(value, npad) + # breakpoint() + axis_names = nn.logical_to_mesh_axes(flash_axis_names) + + @functools.partial( + shard_map.shard_map, + mesh=mesh, + in_specs=( + axis_names, + axis_names, + axis_names, + ), + out_specs=axis_names, + check_rep=False, + ) + def wrap_flash_attention(query, key, value): + masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])] + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks) + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes ) - def wrap_flash_attention(query, key, value): - if self.flash_block_sizes: - block_sizes = self.flash_block_sizes - else: - block_sizes = splash_attention_kernel.BlockSizes( - block_q=min(512, query.shape[2]), - block_kv_compute=min(512, key.shape[2]), - block_kv=min(512, key.shape[2]), - block_q_dkv=min(512, query.shape[2]), - block_kv_dkv=min(512, key.shape[2]), - block_kv_dkv_compute=min(512, query.shape[2]), - block_q_dq=min(512, query.shape[2]), - block_kv_dq=min(512, query.shape[2]), - ) - masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])] - multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks) - splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes - ) - return jax.vmap(splash_kernel)(query, key, value) - - devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"] - # This warning might show up when doing model eval for example, when calculating model flops - # and that is expected. - if not (query.shape[0] / devices_in_data_fsdp).is_integer(): - max_logging.log( - "Warning, batch dimension should be shardable among the devices in data and fsdp" - f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}" - ) - x = wrap_flash_attention(query, key, value) - x = x[:, :, :, :kv_size] - x = self.reshape_heads_to_head_dim(x) - - return x - - def cudnn_flash_attention( - self, - query: Array, - key: Array, - value: Array, - ) -> Array: - """CUDNN Flash Attention with Transformer Engine. - 1. Stable API, supports GQA - 2. Supports head_dim till 128; head_dim=256 support will be added soon - """ - # These imports are only meant to work in a GPU build. - # copied from tpu_flash_attention - query = self.reshape_data_for_cudnn_flash(query) - key = self.reshape_data_for_cudnn_flash(key) - value = self.reshape_data_for_cudnn_flash(value) - - cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV) - axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names) - - query = nn.with_logical_constraint(query, axis_names) - key = nn.with_logical_constraint(key, axis_names) - value = nn.with_logical_constraint(value, axis_names) - - @functools.partial( - shard_map.shard_map, - mesh=self.mesh, - in_specs=(axis_names, axis_names, axis_names), - out_specs=axis_names, - check_rep=False, + return jax.vmap(splash_kernel)(query, key, value) + + devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"] + # This warning might show up when doing model eval for example, when calculating model flops + # and that is expected. + if not (query.shape[0] / devices_in_data_fsdp).is_integer(): + max_logging.log( + "Warning, batch dimension should be shardable among the devices in data and fsdp" + f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}" ) - def wrap_flash_attention(query, key, value): - return jax.vmap(self.dpa_layer)(query, key, value, mask=None) - - out = wrap_flash_attention(query, key, value) - return self.reshape_data_from_cudnn_flash(out) - - def apply_attention_dot(self, query: Array, key: Array, value: Array): - """Apply Attention.""" - if self.split_head_dim: - b = key.shape[0] - query_states = jnp.reshape(query, (b, -1, self.heads, self.dim_head)) - key_states = jnp.reshape(key, (b, -1, self.heads, self.dim_head)) - value_states = jnp.reshape(value, (b, -1, self.heads, self.dim_head)) - else: - query_states = self.reshape_heads_to_batch_dim(query) - key_states = self.reshape_heads_to_batch_dim(key) - value_states = self.reshape_heads_to_batch_dim(value) - - if self.float32_qk_product: - query_states = query_states.astype(jnp.float32) - key_states = key_states.astype(jnp.float32) - - if self.use_memory_efficient_attention: - query_states = query_states.transpose(1, 0, 2) - key_states = key_states.transpose(1, 0, 2) - value_states = value_states.transpose(1, 0, 2) - - # this if statement create a chunk size for each layer of the unet - # the chunk size is equal to the query_length dimension of the deepest layer of the unet - - flatten_latent_dim = query_states.shape[-3] - if flatten_latent_dim % 64 == 0: - query_chunk_size = int(flatten_latent_dim / 64) - elif flatten_latent_dim % 16 == 0: - query_chunk_size = int(flatten_latent_dim / 16) - elif flatten_latent_dim % 4 == 0: - query_chunk_size = int(flatten_latent_dim / 4) - else: - query_chunk_size = int(flatten_latent_dim) - - hidden_states = jax_memory_efficient_attention( - query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 - ) - - hidden_states = hidden_states.transpose(1, 0, 2) + x = wrap_flash_attention(query, key, value) + x = x[:, :, :query_seq_len, :kv_size] + x = _reshape_heads_to_head_dim(x) + + return x + +def _apply_attention_dot( + query: Array, + key: Array, + value: Array, + dtype: jnp.dtype, + heads: int, + dim_head: int, + scale: float, + split_head_dim: bool, + float32_qk_product: bool, + use_memory_efficient_attention: bool +): + """Apply Attention.""" + if split_head_dim: + b = key.shape[0] + query_states = jnp.reshape(query, (b, -1, heads, dim_head)) + key_states = jnp.reshape(key, (b, -1, heads, dim_head)) + value_states = jnp.reshape(value, (b, -1, heads, dim_head)) + else: + query_states = _reshape_heads_to_batch_dim(query, heads) + key_states = _reshape_heads_to_batch_dim(key, heads) + value_states = _reshape_heads_to_batch_dim(value, heads) + + if float32_qk_product: + query_states = query_states.astype(jnp.float32) + key_states = key_states.astype(jnp.float32) + + if use_memory_efficient_attention: + query_states = query_states.transpose(1, 0, 2) + key_states = key_states.transpose(1, 0, 2) + value_states = value_states.transpose(1, 0, 2) + + # this if statement create a chunk size for each layer of the unet + # the chunk size is equal to the query_length dimension of the deepest layer of the unet + + flatten_latent_dim = query_states.shape[-3] + if flatten_latent_dim % 64 == 0: + query_chunk_size = int(flatten_latent_dim / 64) + elif flatten_latent_dim % 16 == 0: + query_chunk_size = int(flatten_latent_dim / 16) + elif flatten_latent_dim % 4 == 0: + query_chunk_size = int(flatten_latent_dim / 4) else: - if self.split_head_dim: - attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states) - else: - attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) - - attention_scores = attention_scores * self.scale - attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2) + query_chunk_size = int(flatten_latent_dim) - attention_probs = attention_probs.astype(self.dtype) + hidden_states = jax_memory_efficient_attention( + query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 + ) - # attend to values - if self.split_head_dim: - hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states) - b = hidden_states.shape[0] - hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head)) - else: - hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + hidden_states = hidden_states.transpose(1, 0, 2) + else: + if split_head_dim: + attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states) + else: + attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) - return hidden_states + attention_scores = attention_scores * scale + attention_probs = nn.softmax(attention_scores, axis=-1 if split_head_dim else 2) - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def reshape_data_for_cudnn_flash(self, tensor): - # reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format) - batch, seq, heads_and_dim_head = tensor.shape - tensor = tensor.reshape(batch, seq, self.heads, heads_and_dim_head // self.heads) - return tensor - - def reshape_data_from_cudnn_flash(self, tensor): - # reshapes from [b, s, h, d] back to [b, s, h * d] - return tensor.reshape(tensor.shape[0], tensor.shape[1], -1) - - def reshape_data_for_flash(self, tensor): - # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format) - batch, seq, heads_and_dim_head = tensor.shape - tensor = tensor.reshape(batch, seq, self.heads, heads_and_dim_head // self.heads) - # Transpose to ('batch', 'heads', 'length', 'kv') - tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - kv_size = tensor.shape[-1] - if kv_size < 128: - npad = ((0, 0), (0, 0), (0, 0), (0, 128 - kv_size)) - tensor = jnp.pad(tensor, npad) - return tensor, kv_size - - def reshape_heads_to_head_dim(self, tensor): - # takes a tensor of shape [b, h, s, d] and reshapes to [b, s, h * d] - # This is used to transform the output of flash attention back into the format of other attention outputs - b, h, s, d = tensor.shape - tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3]) - return jnp.reshape(tensor, (b, -1, h * d)) + attention_probs = attention_probs.astype(dtype) + # attend to values + if split_head_dim: + hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states) + b = hidden_states.shape[0] + hidden_states = jnp.reshape(hidden_states, (b, -1, heads * dim_head)) + else: + hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) + hidden_states = _reshape_batch_dim_to_heads(hidden_states, heads) + + return hidden_states + +def _cudnn_flash_attention( + query: Array, + key: Array, + value: Array, + heads: int, + mesh: Mesh, + dpa_layer: Callable +) -> Array: + """CUDNN Flash Attention with Transformer Engine. + 1. Stable API, supports GQA + 2. Supports head_dim till 128; head_dim=256 support will be added soon + """ + # These imports are only meant to work in a GPU build. + # copied from tpu_flash_attention + query = _reshape_data_for_cudnn_flash(query, heads) + key = _reshape_data_for_cudnn_flash(key, heads) + value = _reshape_data_for_cudnn_flash(value, heads) + + cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV) + axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names) + + query = nn.with_logical_constraint(query, axis_names) + key = nn.with_logical_constraint(key, axis_names) + value = nn.with_logical_constraint(value, axis_names) + + @functools.partial( + shard_map.shard_map, + mesh=mesh, + in_specs=(axis_names, axis_names, axis_names), + out_specs=axis_names, + check_rep=False, + ) + def wrap_flash_attention(query, key, value): + return jax.vmap(dpa_layer)(query, key, value, mask=None) + + out = wrap_flash_attention(query, key, value) + return _reshape_data_from_cudnn_flash(out) + +def _apply_attention( + query: Array, + key: Array, + value: Array, + heads: int, + dim_head: int, + split_head_dim: bool, + float32_qk_product: bool, + attention_kernel: str, + flash_min_seq_length: int, + use_memory_efficient_attention: bool, + scale: float, + dtype: jnp.dtype, + mesh: Mesh, + flash_axis_names: AxisNames, + flash_block_sizes: BlockSizes, + dpa_layer: Callable + ): + """Routes to different attention kernels.""" + _check_attention_inputs(query, key, value) + + if attention_kernel == "flash": + can_use_flash_attention = ( + query.shape[1] >= flash_min_seq_length + and key.shape[1] >= flash_min_seq_length + and value.shape[1] >= flash_min_seq_length + ) + else: + can_use_flash_attention = True + + if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention: + return _apply_attention_dot(query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention) + elif attention_kernel == "flash": + return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes) + elif attention_kernel == "cudnn_flash_te": + return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) + else: + raise ValueError(f"Unexpected attention kernel {attention_kernel=}.") def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): """Multi-head dot product attention with a limited number of queries.""" @@ -416,132 +437,333 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) -class FlaxWanAttention(nn.Module): - query_dim: int - heads: int = 8 - dim_head: int = 64 - dropout: float = 0.0 +class NNXAttentionOp(nnx.Module): + def __init__( + self, + mesh: Mesh, + attention_kernel: str, + scale: int, + heads: int, + dim_head: int, + use_memory_efficient_attention: bool = False, + split_head_dim: bool = False, + float32_qk_product: bool = True, + flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV), + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + dtype: DType = jnp.float32, + quant: Quant = None, + ): + self.dpa_layer = None + if attention_kernel == "cudnn_flash_te": + raise NotImplementedError("Wan 2.1 has not been tested with cudnn_flash_te") + + self.mesh = mesh + self.scale = scale + self.heads = heads + self.dim_head = dim_head + self.attention_kernel = attention_kernel + self.use_memory_efficient_attention = use_memory_efficient_attention + self.split_head_dim = split_head_dim + self.float32_qk_product=float32_qk_product + self.flash_axis_names=flash_axis_names + self.flash_min_seq_length=flash_min_seq_length + self.flash_block_sizes=flash_block_sizes + self.dtype=dtype + self.quant=quant + + def apply_attention(self, query: Array, key: Array, value: Array): + return _apply_attention( + query=query, + key=key, + value=value, + heads=self.heads, + dim_head=self.dim_head, + split_head_dim=self.split_head_dim, + float32_qk_product=self.float32_qk_product, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + use_memory_efficient_attention=self.use_memory_efficient_attention, + scale=self.scale, + dtype=self.dtype, + mesh=self.mesh, + flash_axis_names=self.flash_axis_names, + flash_block_sizes=self.flash_block_sizes, + dpa_layer=self.dpa_layer + ) + +class AttentionOp(nn.Module): + mesh: Mesh + attention_kernel: str + scale: int + heads: int + dim_head: int use_memory_efficient_attention: bool = False split_head_dim: bool = False - attention_kernel: str = "dot_product" + float32_qk_product: bool = True + flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None - mesh: jax.sharding.Mesh = None - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - query_axis_names: AxisNames = (BATCH, LENGTH, HEAD) - key_axis_names: AxisNames = (BATCH, LENGTH, HEAD) - value_axis_names: AxisNames = (BATCH, LENGTH, HEAD) - out_axis_names: AxisNames = (BATCH, LENGTH, EMBED) - precision: jax.lax.Precision = None - qkv_bias: bool = False + dtype: DType = jnp.float32 + quant: Quant = None def setup(self): - if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: - raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") - inner_dim = self.dim_head * self.heads - scale = self.dim_head**-0.5 + if self.attention_kernel == "cudnn_flash_te": + from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error - self.attention_op = AttentionOp( - mesh=self.mesh, - attention_kernel=self.attention_kernel, - scale=scale, - heads=self.heads, - dim_head=self.dim_head, - flash_min_seq_length=self.flash_min_seq_length, - use_memory_efficient_attention=self.use_memory_efficient_attention, - split_head_dim=self.split_head_dim, - flash_block_sizes=self.flash_block_sizes, - dtype=self.dtype, - float32_qk_product=False, + self.dpa_layer = DotProductAttention( + head_dim=self.dim_head, + num_attention_heads=self.heads, + num_gqa_groups=self.heads, + attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal' + attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + # attention_dropout=self.dropout_rate, + dropout_rng_name="aqt", + dtype=self.dtype, + # float32_logits=self.float32_logits, + qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + scale_factor=self.scale, + transpose_batch_sequence=False, + ) + + def apply_attention(self, query: Array, key: Array, value: Array): + return _apply_attention( + query=query, + key=key, + value=value, + heads=self.heads, + dim_head=self.dim_head, + split_head_dim=self.split_head_dim, + float32_qk_product=self.float32_qk_product, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + use_memory_efficient_attention=self.use_memory_efficient_attention, + scale=self.scale, + dtype=self.dtype, + mesh=self.mesh, + flash_axis_names=self.flash_axis_names, + flash_block_sizes=self.flash_block_sizes, + dpa_layer=self.dpa_layer ) - kernel_axes = ("embed", "heads") - qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes) - qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "heads")) +class FlaxWanAttention(nnx.Module): + def __init__( + self, + rngs: nnx.Rngs, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + eps: float = 1e-6, + qk_norm: str = "rms_norm_across_heads", + use_memory_efficient_attention: bool = False, + split_head_dim: bool = False, + attention_kernel: str = "flash", + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + query_axis_names: AxisNames = (BATCH, LENGTH, HEAD), + key_axis_names: AxisNames = (BATCH, LENGTH, HEAD), + value_axis_names: AxisNames = (BATCH, LENGTH, HEAD), + out_axis_names: AxisNames = (BATCH, LENGTH, EMBED), + precision: jax.lax.Precision = None, + qkv_bias: bool = False, + quant: Quant = None, + ): + # TODO - Params from pytorch implementation + # to set for the creation of this. + # bias is True + # upcast_attention - False + # upcast_softmax - False + # cross_attention_norm - None + # cross_attention_norm_num_groups - 32 + # qk_norm - rms_norm_across_heads + # added_kv_proj_dim + # norm_num_groups: Optional[int] = None, + # spatial_norm_dim: Optional[int] = None, + # out_bias: bool = True, + # scale_qk: bool = True, + # only_cross_attention - False + # eps - 1e-06 + # rescale_output_factor: float = 1.0, + # residual_connection: bool = False, + # _from_deprecated_attn_block: bool = False, + # processor: Optional["AttnProcessor"] = WanAttnProcessor2_0 + # out_dim: int = None, + # out_context_dim: int = None, + # context_pre_only=None, + # pre_only=False, + # elementwise_affine: bool = True, + # is_causal: bool = False, + + if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None: + raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") + self.dim_head = dim_head + self.heads = heads + self.inner_dim = dim_head * heads + scale = dim_head**-0.5 + self.qk_norm = qk_norm + self.query_axis_names = query_axis_names + self.key_axis_names = key_axis_names + self.value_axis_names = value_axis_names + self.out_axis_names = out_axis_names + + self.attention_op = NNXAttentionOp( + mesh=mesh, + attention_kernel=attention_kernel, + scale=scale, + heads=heads, + dim_head=dim_head, + use_memory_efficient_attention=use_memory_efficient_attention, + split_head_dim=split_head_dim, + float32_qk_product=False, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + dtype=dtype, + quant=quant + ) - self.query = nn.Dense( - inner_dim, + kernel_axes = ("embed", "heads") + qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes) + + self.query = nnx.Linear( + rngs=rngs, + in_features=self.inner_dim, + out_features=self.inner_dim, kernel_init=qkv_init_kernel, - use_bias=False, - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="to_q", - precision=self.precision, + use_bias=qkv_bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) - self.key = nn.Dense( - inner_dim, + self.key = nnx.Linear( + rngs=rngs, + in_features=self.inner_dim, + out_features=self.inner_dim, kernel_init=qkv_init_kernel, - use_bias=False, - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="to_k", - precision=self.precision, + use_bias=qkv_bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) - self.value = nn.Dense( - inner_dim, + self.value = nnx.Linear( + rngs=rngs, + in_features=self.inner_dim, + out_features=self.inner_dim, kernel_init=qkv_init_kernel, - use_bias=False, - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="to_v", - precision=self.precision, + use_bias=qkv_bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) - self.query_norm = nn.RMSNorm( - dtype=self.dtype, - scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), - param_dtype=self.weights_dtype, - ) - self.key_norm = nn.RMSNorm( - dtype=self.dtype, - scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), - param_dtype=self.weights_dtype, + self.proj_attn = nnx.Linear( + rngs=rngs, + in_features=self.inner_dim, + out_features=self.inner_dim, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")), + use_bias=qkv_bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) - self.proj_attn = nn.Dense( - self.query_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("heads", "embed")), - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="to_out_0", - precision=self.precision, - ) - self.dropout_layer = nn.Dropout(rate=self.dropout) + self.query_norm = None + self.key_norm = None + if qk_norm is not None: + self.query_norm = nnx.RMSNorm( + num_features=self.inner_dim, + rngs=rngs, + epsilon=eps, + dtype=dtype, + scale_init=nnx.with_partitioning(nnx.initializers.ones, ("heads", )), + param_dtype=weights_dtype + ) - def call( - self, - hidden_states: Array, - encoder_hidden_states: Optional[Array], - rotary_emb: Optional[Array], - deterministic: bool = True, - ): - encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + self.key_norm = nnx.RMSNorm( + num_features=self.inner_dim, + rngs=rngs, + dtype=dtype, + scale_init=nnx.with_partitioning(nnx.initializers.ones, ("heads", )), + param_dtype=weights_dtype + ) + + def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]: + dtype = xq.dtype + reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) + reshape_xk = xq.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) + xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1]) + xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) + + freqs_cis = freqs_cis[None, None, ...] + xq_out_complex = xq_ * freqs_cis + xk_out_complex = xk_ * freqs_cis + + xq_out = jnp.stack([jnp.real(xq_out_complex), jnp.imag(xq_out_complex)], axis=-1).reshape(xq.shape).astype(dtype) + xk_out = jnp.stack([jnp.real(xk_out_complex), jnp.imag(xk_out_complex)], axis=-1).reshape(xk.shape).astype(dtype) + + return xq_out, xk_out + + def __call__( + self, + hidden_states: jax.Array, + encoder_hidden_states: jax.Array, + rotary_emb: Optional[jax.Array] = None + ) -> jax.Array: + batch_size = hidden_states.shape[0] + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states query_proj = self.query(hidden_states) key_proj = self.key(encoder_hidden_states) value_proj = self.value(encoder_hidden_states) - query_proj = self.query_norm(query_proj) - key_proj = self.key_norm(key_proj) - - if rotary_emb: - query_proj, key_proj = self.apply_rope(query_proj, key_proj, rotary_emb) - query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names) key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) - hidden_states = self.attention_op.apply_attention(query_proj, key_proj, value_proj) + if self.qk_norm: + query_proj = self.query_norm(query_proj) + key_proj = self.key_norm(key_proj) + query_proj = _unflatten_heads(query_proj, self.heads) + key_proj = _unflatten_heads(key_proj, self.heads) + if rotary_emb is not None: + query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) + #breakpoint() + query_proj = _reshape_heads_to_head_dim(query_proj) + key_proj = _reshape_heads_to_head_dim(key_proj) + attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) + breakpoint() + - hidden_states = self.proj_attn(hidden_states) - hidden_states = nn.with_logical_constraint(hidden_states, (BATCH, LENGTH, HEAD)) - return self.dropout_layer(hidden_states, deterministic=deterministic) + def setup(self): + if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: + raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") + inner_dim = self.dim_head * self.heads + scale = self.dim_head**-0.5 + + self.attention_op = NNXAttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + scale=scale, + heads=self.heads, + dim_head=self.dim_head, + flash_min_seq_length=self.flash_min_seq_length, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + flash_block_sizes=self.flash_block_sizes, + dtype=self.dtype, + float32_qk_product=False, + ) + class FlaxFluxAttention(nn.Module): query_dim: int heads: int = 8 diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 861a8366d..07774303b 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -27,6 +27,7 @@ NNXTimestepEmbedding, NNXPixArtAlphaTextProjection ) +from ...normalization_flax import FP32LayerNorm BlockSizes = common_types.BlockSizes @@ -181,6 +182,29 @@ def __init__( ) +class WanTransformerBlock(nnx.Module): + def __init__( + self, + rngs: nnx.Rngs, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: Optional[int] = None + ): + self.norm1 = FP32LayerNorm( + dim=dim, + eps=eps, + elementwise_affine=False + ) + + def __call__(self): + pass + + + class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): @register_to_config @@ -242,6 +266,13 @@ def __init__( pos_embed_seq_len=pos_embed_seq_len ) + # 3. Transformer blocks + blocks = [] + for _ in range(num_layers): + block = WanTransformerBlock() + blocks.append(block) + self.blocks = blocks + def __call__( self, hidden_states: jax.Array, @@ -265,9 +296,13 @@ def __call__( temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) - #hidden_states = - # Torch shape: ([1, 5120, 21, 45, 80]) - # Jax shape: (1, 21, 45, 80, 5120) so channels is 5120 + timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) + + if encoder_hidden_states_image is not None: + raise NotImplementedError("img2vid is not yet implemented.") + + # for block in self.blocks: + return hidden_states \ No newline at end of file From 1abc00cc756ecb26066d038cd8e4e82496342f75 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 13 May 2025 16:14:44 +0000 Subject: [PATCH 30/54] wrap up attention. --- src/maxdiffusion/models/attention_flax.py | 88 ++++++------------- .../tests/wan_transformer_test.py | 80 +++++++++++++++++ 2 files changed, 106 insertions(+), 62 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index bcd3c723c..505d39e41 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -101,16 +101,28 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size): Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. """ tensor = _unflatten_heads(tensor, heads) + + # pad head_dim to 128 if less than that. kv_size = tensor.shape[-1] + head_dim_pad = 0 if kv_size < 128: - npad = ((0, 0), (0, 0), (0, 0), (0, 128 - kv_size)) - tensor = jnp.pad(tensor, npad) + head_dim_pad = 128 - kv_size + + # pad seq_len to a multiple of flash_block_size if needed. seq_len = tensor.shape[2] + # remainder rem = seq_len % flash_block_size + seq_len_pad = 0 if rem != 0: + # multiplier mul = seq_len // flash_block_size - npad = ((0, 0), (0, 0), (0, (mul + 1)*flash_block_size - seq_len), (0, 0)) + # pad to the closest multiplier of flash_block_size + seq_len_pad = (mul + 1) * flash_block_size - seq_len + + if kv_size < 128 or rem != 0: + npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad)) tensor = jnp.pad(tensor, npad) + return tensor, kv_size, seq_len def _tpu_flash_attention( @@ -140,15 +152,7 @@ def _tpu_flash_attention( query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q) key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute) value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute) - # query_seq_len = query.shape[2] - # query_rem = query_seq_len % block_sizes.block_q - # if query_rem != 0: - # query_mul = query_seq_len // block_sizes.block_q - # npad = ((0, 0), (0, 0), (0, (query_mul + 1)*block_sizes.block_q - query.shape[2]), (0, 0)) - # query = jnp.pad(query, npad) - # key = jnp.pad(key, npad) - # value = jnp.pad(value, npad) - # breakpoint() + axis_names = nn.logical_to_mesh_axes(flash_axis_names) @functools.partial( @@ -456,7 +460,7 @@ def __init__( ): self.dpa_layer = None if attention_kernel == "cudnn_flash_te": - raise NotImplementedError("Wan 2.1 has not been tested with cudnn_flash_te") + raise NotImplementedError(f"{self} has not been tested with {attention_kernel}") self.mesh = mesh self.scale = scale @@ -574,34 +578,13 @@ def __init__( qkv_bias: bool = False, quant: Quant = None, ): - # TODO - Params from pytorch implementation - # to set for the creation of this. - # bias is True - # upcast_attention - False - # upcast_softmax - False - # cross_attention_norm - None - # cross_attention_norm_num_groups - 32 - # qk_norm - rms_norm_across_heads - # added_kv_proj_dim - # norm_num_groups: Optional[int] = None, - # spatial_norm_dim: Optional[int] = None, - # out_bias: bool = True, - # scale_qk: bool = True, - # only_cross_attention - False - # eps - 1e-06 - # rescale_output_factor: float = 1.0, - # residual_connection: bool = False, - # _from_deprecated_attn_block: bool = False, - # processor: Optional["AttnProcessor"] = WanAttnProcessor2_0 - # out_dim: int = None, - # out_context_dim: int = None, - # context_pre_only=None, - # pre_only=False, - # elementwise_affine: bool = True, - # is_causal: bool = False, + + if attention_kernel == "cudnn_flash_te" or attention_kernel == "dot_product": + raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}") if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None: raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") + self.dim_head = dim_head self.heads = heads self.inner_dim = dim_head * heads @@ -717,7 +700,8 @@ def __call__( encoder_hidden_states: jax.Array, rotary_emb: Optional[jax.Array] = None ) -> jax.Array: - batch_size = hidden_states.shape[0] + dtype = hidden_states.dtype + # batch_size = hidden_states.shape[0] if encoder_hidden_states is None: encoder_hidden_states = hidden_states query_proj = self.query(hidden_states) @@ -735,35 +719,15 @@ def __call__( key_proj = _unflatten_heads(key_proj, self.heads) if rotary_emb is not None: query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) - #breakpoint() query_proj = _reshape_heads_to_head_dim(query_proj) key_proj = _reshape_heads_to_head_dim(key_proj) attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) - breakpoint() - + attn_output = attn_output.astype(dtype=dtype) + hidden_states = self.proj_attn(hidden_states) + return hidden_states - def setup(self): - if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: - raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") - inner_dim = self.dim_head * self.heads - scale = self.dim_head**-0.5 - - self.attention_op = NNXAttentionOp( - mesh=self.mesh, - attention_kernel=self.attention_kernel, - scale=scale, - heads=self.heads, - dim_head=self.dim_head, - flash_min_seq_length=self.flash_min_seq_length, - use_memory_efficient_attention=self.use_memory_efficient_attention, - split_head_dim=self.split_head_dim, - flash_block_sizes=self.flash_block_sizes, - dtype=self.dtype, - float32_qk_product=False, - ) - class FlaxFluxAttention(nn.Module): query_dim: int heads: int = 8 diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 98892400a..0945c6416 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -14,15 +14,25 @@ limitations under the License. """ +import os import jax import jax.numpy as jnp import unittest from absl.testing import absltest from flax import nnx +from jax.sharding import Mesh +from .. import pyconfig +from ..max_utils import ( + create_device_mesh, + get_flash_block_sizes +) from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed, WanTimeTextImageEmbedding from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection from ..models.normalization_flax import FP32LayerNorm +from ..models.attention_flax import FlaxWanAttention + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) class WanTransformerTest(unittest.TestCase): def setUp(self): @@ -108,6 +118,76 @@ def test_wan_time_text_embedding(self): assert temb.shape == (batch_size, dim) assert timestep_proj.shape == (batch_size, time_proj_dim) assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim) + + def test_wan_attention(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True + ) + config = pyconfig.config + + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_states_shape = (batch_size, frames, height, width, channels) + dummy_hidden_states = jnp.ones(hidden_states_shape) + wan_rot_embed = WanRotaryPosEmbed( + attention_head_dim=128, + patch_size=[1, 2, 2], + max_seq_len=1024 + ) + dummy_rotary_emb = wan_rot_embed(dummy_hidden_states) + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + devices_array = create_device_mesh(config) + + flash_block_sizes = get_flash_block_sizes(config) + + mesh = Mesh(devices_array, config.mesh_axes) + batch_size = 1 + query_dim = 5120 + attention = FlaxWanAttention( + rngs=rngs, + query_dim=query_dim, + heads=40, + dim_head=128, + attention_kernel="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + + dummy_hidden_states_shape = (batch_size, 75600, query_dim) + + dummy_hidden_states = jnp.ones(dummy_hidden_states_shape) + dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape) + + dummy_output = attention( + hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb + ) + assert dummy_output.shape == dummy_hidden_states_shape + + # dot product + try: + attention = FlaxWanAttention( + rngs=rngs, + query_dim=query_dim, + heads=40, + dim_head=128, + attention_kernel="dot_product", + split_head_dim=True, + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + except NotImplementedError as e: + pass + + if __name__ == "__main__": absltest.main() \ No newline at end of file From 4c00085621308297a0726cf32e8e917a9dfd6cd5 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 13 May 2025 20:34:19 +0000 Subject: [PATCH 31/54] add transformer block --- src/maxdiffusion/models/attention_flax.py | 15 +- .../wan/transformers/transformer_wan.py | 202 +++++++++++++++++- .../tests/wan_transformer_test.py | 71 +++++- 3 files changed, 274 insertions(+), 14 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 505d39e41..342ee22d2 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -697,11 +697,11 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup def __call__( self, hidden_states: jax.Array, - encoder_hidden_states: jax.Array, + encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None ) -> jax.Array: + dtype = hidden_states.dtype - # batch_size = hidden_states.shape[0] if encoder_hidden_states is None: encoder_hidden_states = hidden_states query_proj = self.query(hidden_states) @@ -715,12 +715,14 @@ def __call__( if self.qk_norm: query_proj = self.query_norm(query_proj) key_proj = self.key_norm(key_proj) - query_proj = _unflatten_heads(query_proj, self.heads) - key_proj = _unflatten_heads(key_proj, self.heads) + if rotary_emb is not None: + query_proj = _unflatten_heads(query_proj, self.heads) + key_proj = _unflatten_heads(key_proj, self.heads) query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) - query_proj = _reshape_heads_to_head_dim(query_proj) - key_proj = _reshape_heads_to_head_dim(key_proj) + query_proj = _reshape_heads_to_head_dim(query_proj) + key_proj = _reshape_heads_to_head_dim(key_proj) + attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) attn_output = attn_output.astype(dtype=dtype) @@ -1309,7 +1311,6 @@ def __call__(self, hidden_states, context, deterministic=True, cross_attention_k hidden_states = hidden_states + residual return self.dropout_layer(hidden_states, deterministic=deterministic) - class FlaxFeedForward(nn.Module): r""" Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 07774303b..041b76916 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -28,6 +28,7 @@ NNXPixArtAlphaTextProjection ) from ...normalization_flax import FP32LayerNorm +from ...attention_flax import FlaxWanAttention BlockSizes = common_types.BlockSizes @@ -181,6 +182,89 @@ def __init__( rope_max_seq_len ) +class ApproximateGELU(nnx.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + """ + def __init__( + self, + rngs: nnx.Rngs, + dim_in: int, + dim_out: int, + bias: bool, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): + self.proj = nnx.Linear( + rngs=rngs, + in_features=dim_in, + out_features=dim_out, + use_bias=bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + ) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.proj(x) + return x * jax.nn.sigmoid(1.702 * x) + + +class WanFeedForward(nnx.Module): + def __init__( + self, + rngs: nnx.Rngs, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim: int = None, + bias: bool = True, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + self.act_fn = None + if activation_fn == "gelu-approximate": + self.act_fn = ApproximateGELU( + rngs=rngs, + dim_in=dim, + dim_out=inner_dim, + bias=bias, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) + else: + raise NotImplementedError(f"{activation_fn} is not implemented.") + + self.proj_out = nnx.Linear( + rngs=rngs, + in_features=inner_dim, + out_features=dim_out, + use_bias=bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + ) + + def __call__(self, hidden_states: jax.Array) -> jax.Array: + hidden_states = self.act_fn(hidden_states) + return self.proj_out(hidden_states) + + class WanTransformerBlock(nnx.Module): def __init__( @@ -192,17 +276,107 @@ def __init__( qk_norm: str = "rms_norm_across_heads", cross_attn_norm: bool = False, eps: float = 1e-6, - added_kv_proj_dim: Optional[int] = None + # In torch, this is none, so it can be ignored. + # added_kv_proj_dim: Optional[int] = None, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", + ): + + # 1. Self-attention self.norm1 = FP32LayerNorm( + rngs=rngs, dim=dim, eps=eps, elementwise_affine=False ) + self.attn1 = FlaxWanAttention( + rngs=rngs, + query_dim=dim, + heads=num_heads, + dim_head= dim // num_heads, + qk_norm=qk_norm, + eps=eps, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + attention_kernel=attention + ) + + # 1. Cross-attention + self.attn2 = FlaxWanAttention( + rngs=rngs, + query_dim=dim, + heads=num_heads, + dim_head= dim // num_heads, + qk_norm=qk_norm, + eps=eps, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + attention_kernel=attention + ) + assert cross_attn_norm == True + self.norm2 = FP32LayerNorm( + rngs=rngs, + dim=dim, + eps=eps, + elementwise_affine=True + ) + + # 3. Feed-forward + self.ffn = WanFeedForward( + rngs=rngs, + dim=dim, + inner_dim=ffn_dim, + activation_fn="gelu-approximate", + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) + self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) + + key = rngs.params() + self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5) - def __call__(self): - pass + def __call__( + self, + hidden_states: jax.Array, + encoder_hidden_states: jax.Array, + temb: jax.Array, + rotary_emb: jax.Array + ): + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( + (self.scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 + ) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) + attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) + hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)) + attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(hidden_states.dtype) + return hidden_states class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -269,7 +443,22 @@ def __init__( # 3. Transformer blocks blocks = [] for _ in range(num_layers): - block = WanTransformerBlock() + block = WanTransformerBlock( + rngs=rngs, + dim=inner_dim, + ffn_dim=ffn_dim, + num_attention_heads=num_attention_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + attention=attention + ) blocks.append(block) self.blocks = blocks @@ -301,8 +490,9 @@ def __call__( if encoder_hidden_states_image is not None: raise NotImplementedError("img2vid is not yet implemented.") - # for block in self.blocks: - + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + breakpoint() return hidden_states \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 0945c6416..55b2b9215 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -27,7 +27,7 @@ create_device_mesh, get_flash_block_sizes ) -from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed, WanTimeTextImageEmbedding +from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection from ..models.normalization_flax import FP32LayerNorm from ..models.attention_flax import FlaxWanAttention @@ -119,6 +119,75 @@ def test_wan_time_text_embedding(self): assert timestep_proj.shape == (batch_size, time_proj_dim) assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim) + def test_wan_block(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True + ) + config = pyconfig.config + + devices_array = create_device_mesh(config) + + flash_block_sizes = get_flash_block_sizes(config) + + mesh = Mesh(devices_array, config.mesh_axes) + + dim=5120 + ffn_dim=13824 + num_heads=40 + qk_norm="rms_norm_across_heads" + cross_attn_norm=True + eps=1e-6 + + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_dim = 75600 + + # for rotary post embed. + hidden_states_shape = (batch_size, frames, height, width, channels) + dummy_hidden_states = jnp.ones(hidden_states_shape) + + wan_rot_embed = WanRotaryPosEmbed( + attention_head_dim=128, + patch_size=[1, 2, 2], + max_seq_len=1024 + ) + dummy_rotary_emb = wan_rot_embed(dummy_hidden_states) + assert dummy_rotary_emb.shape == (batch_size, 1, hidden_dim, 64) + + # for transformer block + dummy_hidden_states = jnp.ones((batch_size, hidden_dim, dim)) + + dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim)) + + dummy_temb = jnp.ones((batch_size, 6, dim)) + + wan_block = WanTransformerBlock( + rngs=rngs, + dim=dim, + ffn_dim=ffn_dim, + num_heads=num_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + attention="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes + ) + + dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb) + assert dummy_output.shape == dummy_hidden_states.shape + + + def test_wan_attention(self): pyconfig.initialize( [ From 440f39c2a4a69da4c9103f37f7669d91175beaa3 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 13 May 2025 22:59:38 +0000 Subject: [PATCH 32/54] wan transformer with in/out shapes verified --- .../wan/transformers/transformer_wan.py | 28 +++++++++-- .../tests/wan_transformer_test.py | 48 ++++++++++++++++++- 2 files changed, 71 insertions(+), 5 deletions(-) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 041b76916..d524b2b4e 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -15,10 +15,11 @@ """ from typing import Tuple, Optional, Dict, Union, Any +import math import jax import jax.numpy as jnp from flax import nnx -from .... import common_types, max_logging +from .... import common_types from ...modeling_flax_utils import FlaxModelMixin, get_activation from ....configuration_utils import ConfigMixin, register_to_config from ...embeddings_flax import ( @@ -447,7 +448,7 @@ def __init__( rngs=rngs, dim=inner_dim, ffn_dim=ffn_dim, - num_attention_heads=num_attention_heads, + num_heads=num_attention_heads, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, @@ -462,6 +463,20 @@ def __init__( blocks.append(block) self.blocks = blocks + self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) + self.proj_out = nnx.Linear( + rngs=rngs, + in_features=inner_dim, + out_features=out_channels * math.prod(patch_size), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + ) + key = rngs.params() + self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5) + def __call__( self, hidden_states: jax.Array, @@ -492,7 +507,14 @@ def __call__( for block in self.blocks: hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) - breakpoint() + + shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) + + hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) + hidden_states = self.proj_out(hidden_states) + # TODO - can this reshape happen in a single command? + hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1) + hidden_states = hidden_states.reshape(batch_size, num_frames, height, width, num_channels) return hidden_states \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 55b2b9215..4b48fe349 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -27,7 +27,9 @@ create_device_mesh, get_flash_block_sizes ) -from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock +from ..models.wan.transformers.transformer_wan import ( + WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock, WanModel +) from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection from ..models.normalization_flax import FP32LayerNorm from ..models.attention_flax import FlaxWanAttention @@ -256,7 +258,49 @@ def test_wan_attention(self): except NotImplementedError as e: pass - + def test_wan_model(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True + ) + config = pyconfig.config + + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_states_shape = (batch_size, frames, height, width, channels) + dummy_hidden_states = jnp.ones(hidden_states_shape) + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + devices_array = create_device_mesh(config) + + flash_block_sizes = get_flash_block_sizes(config) + + mesh = Mesh(devices_array, config.mesh_axes) + batch_size = 1 + query_dim = 5120 + wan_model = WanModel( + rngs=rngs, + attention="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + + dummy_timestep = jnp.ones((batch_size)) + dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) + + dummy_output = wan_model( + hidden_states=dummy_hidden_states, + timestep=dummy_timestep, + encoder_hidden_states=dummy_encoder_hidden_states + ) + assert dummy_output.shape == hidden_states_shape if __name__ == "__main__": absltest.main() \ No newline at end of file From 0ef8c71a83e32089023aa2f2aebcfa4ffa49667e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 14 May 2025 21:37:58 +0000 Subject: [PATCH 33/54] load wan 2.1 transformer weights. --- src/maxdiffusion/configs/base_wan_14b.yml | 2 +- src/maxdiffusion/models/attention_flax.py | 16 +++--- src/maxdiffusion/models/wan/wan_utils.py | 63 ++++++++++++++++++++++- 3 files changed, 69 insertions(+), 12 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index d7a802225..6551902a7 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -27,7 +27,7 @@ gcs_metrics: False save_config_to_gcs: False log_period: 100 -pretrained_model_name_or_path: '' +pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' unet_checkpoint: '' revision: '' diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 342ee22d2..927b5f504 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -618,7 +618,6 @@ def __init__( in_features=self.inner_dim, out_features=self.inner_dim, kernel_init=qkv_init_kernel, - use_bias=qkv_bias, dtype=dtype, param_dtype=weights_dtype, precision=precision, @@ -629,7 +628,6 @@ def __init__( in_features=self.inner_dim, out_features=self.inner_dim, kernel_init=qkv_init_kernel, - use_bias=qkv_bias, dtype=dtype, param_dtype=weights_dtype, precision=precision, @@ -640,7 +638,6 @@ def __init__( in_features=self.inner_dim, out_features=self.inner_dim, kernel_init=qkv_init_kernel, - use_bias=qkv_bias, dtype=dtype, param_dtype=weights_dtype, precision=precision, @@ -651,16 +648,15 @@ def __init__( in_features=self.inner_dim, out_features=self.inner_dim, kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")), - use_bias=qkv_bias, dtype=dtype, param_dtype=weights_dtype, precision=precision, ) - self.query_norm = None - self.key_norm = None + self.norm_q = None + self.norm_k = None if qk_norm is not None: - self.query_norm = nnx.RMSNorm( + self.norm_q = nnx.RMSNorm( num_features=self.inner_dim, rngs=rngs, epsilon=eps, @@ -669,7 +665,7 @@ def __init__( param_dtype=weights_dtype ) - self.key_norm = nnx.RMSNorm( + self.norm_k = nnx.RMSNorm( num_features=self.inner_dim, rngs=rngs, dtype=dtype, @@ -713,8 +709,8 @@ def __call__( value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) if self.qk_norm: - query_proj = self.query_norm(query_proj) - key_proj = self.key_norm(key_proj) + query_proj = self.norm_q(query_proj) + key_proj = self.norm_k(key_proj) if rotary_emb is not None: query_proj = _unflatten_heads(query_proj, self.heads) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 1a9948fdb..4f7effad6 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -1,9 +1,10 @@ +import json import jax import jax.numpy as jnp from maxdiffusion import max_logging from huggingface_hub import hf_hub_download from safetensors import safe_open -from flax.traverse_util import unflatten_dict +from flax.traverse_util import unflatten_dict, flatten_dict from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict) @@ -16,6 +17,66 @@ def _tuple_str_to_int(in_tuple): out_list.append(item) return tuple(out_list) +def rename_for_nnx(key): + new_key = key + if "norm_k" in key or "norm_q" in key: + new_key = key[:-1] + ("scale",) + return new_key + +def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): + device = jax.devices(device)[0] + with jax.default_device(device): + if hf_download: + # download the index file for sharded models. + index_file_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename="diffusion_pytorch_model.safetensors.index.json") + # open the index file. + with open(index_file_path, 'r') as f: + index_dict = json.load(f) + model_files = set() + for key in index_dict["weight_map"].keys(): + model_files.add(index_dict["weight_map"][key]) + + model_files = list(model_files) + tensors = {} + for model_file in model_files: + ckpt_shard_path = hf_hub_download( + pretrained_model_name_or_path, subfolder="transformer", filename=model_file + ) + # now get all the filenames for the model that need downloading + max_logging.log(f"Load and port Wan 2.1 transformer on {device}") + + if ckpt_shard_path is not None: + with safe_open(ckpt_shard_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = torch2jax(f.get_tensor(k)) + flax_state_dict = {} + cpu = jax.local_devices(backend="cpu")[0] + flattened_dict = flatten_dict(eval_shapes) + # turn all block numbers to strings just for matching weights. + # Later they will be turned back to ints. + random_flax_state_dict = {} + for key in flattened_dict: + string_tuple = tuple([str(item) for item in key]) + random_flax_state_dict[string_tuple] = flattened_dict[key] + del flattened_dict + for pt_key, tensor in tensors.items(): + renamed_pt_key = rename_key(pt_key) + renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") + renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn") + renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out") + renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") + renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") + pt_tuple_key = tuple(renamed_pt_key.split(".")) + + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) + flax_key = rename_for_nnx(flax_key) + flax_key = _tuple_str_to_int(flax_key) + flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) + del tensors + jax.clear_caches() + return flax_state_dict def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] From 82b719e005efd391eae9d6b138902d66d3a6f3ef Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 20 May 2025 23:33:21 +0000 Subject: [PATCH 34/54] fix rope calculations. --- src/maxdiffusion/configs/base_wan_14b.yml | 4 +- src/maxdiffusion/models/attention_flax.py | 35 +++++++++-- src/maxdiffusion/models/embeddings_flax.py | 2 +- .../wan/transformers/transformer_wan.py | 63 ++++++++++++++----- 4 files changed, 81 insertions(+), 23 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 6551902a7..f668503f4 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -139,8 +139,8 @@ data_sharding: [['data', 'fsdp', 'tensor']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 -ici_data_parallelism: -1 -ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 # Dataset diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 927b5f504..4a62be819 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -15,6 +15,7 @@ import functools import math from typing import Optional, Callable, Tuple +import numpy as np import flax.linen as nn from flax import nnx import jax @@ -318,7 +319,7 @@ def _apply_attention( ): """Routes to different attention kernels.""" _check_attention_inputs(query, key, value) - + if attention_kernel == "flash": can_use_flash_attention = ( query.shape[1] >= flash_min_seq_length @@ -578,8 +579,7 @@ def __init__( qkv_bias: bool = False, quant: Quant = None, ): - - if attention_kernel == "cudnn_flash_te" or attention_kernel == "dot_product": + if attention_kernel == "cudnn_flash_te": raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}") if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None: @@ -676,7 +676,7 @@ def __init__( def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]: dtype = xq.dtype reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) - reshape_xk = xq.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) + reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1]) xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) @@ -696,13 +696,19 @@ def __call__( encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None ) -> jax.Array: - + print(" -- -- WanAttention -- ") dtype = hidden_states.dtype if encoder_hidden_states is None: encoder_hidden_states = hidden_states query_proj = self.query(hidden_states) + print("query_proj min: ", np.min(query_proj)) + print("query_proj max: ", np.max(query_proj)) key_proj = self.key(encoder_hidden_states) + print("key_proj min: ", np.min(key_proj)) + print("key_proj max: ", np.max(key_proj)) value_proj = self.value(encoder_hidden_states) + print("value_proj min: ", np.min(value_proj)) + print("value_proj max: ", np.max(value_proj)) query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names) key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) @@ -711,18 +717,37 @@ def __call__( if self.qk_norm: query_proj = self.norm_q(query_proj) key_proj = self.norm_k(key_proj) + print("query_proj min: ", np.min(query_proj)) + print("query_proj max: ", np.max(query_proj)) + print("key_proj min: ", np.min(key_proj)) + print("key_proj max: ", np.max(key_proj)) if rotary_emb is not None: query_proj = _unflatten_heads(query_proj, self.heads) key_proj = _unflatten_heads(key_proj, self.heads) + # value_proj = _unflatten_heads(value_proj, self.heads) query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) + print("Rope query_proj min: ", np.min(query_proj)) + print("Rope query_proj max: ", np.max(query_proj)) + print("Rope key_proj min: ", np.min(key_proj)) + print("Rope key_proj max: ", np.max(key_proj)) + #breakpoint() query_proj = _reshape_heads_to_head_dim(query_proj) key_proj = _reshape_heads_to_head_dim(key_proj) attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) + try: + print("attn_output min: ", np.min(attn_output)) + print("attn_output_for_print max: ", np.max(attn_output)) + except: + pass attn_output = attn_output.astype(dtype=dtype) hidden_states = self.proj_attn(hidden_states) + print("hidden_states min: ", np.min(hidden_states)) + print("hidden_states max: ", np.max(hidden_states)) + print(" -- -- WanAttention DONE -- ") + #breakpoint() return hidden_states diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index 38a633e29..ad2283034 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -227,7 +227,7 @@ def get_1d_rotary_pos_embed( out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1) else: # Wan 2.1 - out = jax.lax.complex(jnp.ones_like(freqs), freqs) + out = jax.lax.complex(jnp.cos(freqs), jnp.sin(freqs)) return out class NNXPixArtAlphaTextProjection(nnx.Module): diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index d524b2b4e..da479d237 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp from flax import nnx +import numpy as np from .... import common_types from ...modeling_flax_utils import FlaxModelMixin, get_activation from ....configuration_utils import ConfigMixin, register_to_config @@ -58,12 +59,7 @@ def __init__( use_real=False ) freqs.append(freq) - self.freqs = jnp.concatenate(freqs, axis=1) - - def __call__(self, hidden_states: jax.Array) -> jax.Array: - _, num_frames, height, width, _ = hidden_states.shape - p_t, p_h, p_w = self.patch_size - ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + freqs = jnp.concatenate(freqs, axis=1) sizes = [ self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), @@ -72,16 +68,21 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array: ] cumulative_sizes = jnp.cumsum(jnp.array(sizes)) split_indices = cumulative_sizes[:-1] - freqs_split = jnp.split(self.freqs, split_indices, axis=1) + self.freqs_split = jnp.split(freqs, split_indices, axis=1) + + def __call__(self, hidden_states: jax.Array) -> jax.Array: + _, num_frames, height, width, _ = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1) - freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1])) + freqs_f = jnp.expand_dims(jnp.expand_dims(self.freqs_split[0][:ppf], axis=1), axis=1) + freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, self.freqs_split[0].shape[-1])) - freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2) - freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1])) + freqs_h = jnp.expand_dims(jnp.expand_dims(self.freqs_split[1][:pph], axis=0), axis=2) + freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, self.freqs_split[1].shape[-1])) - freqs_w = jnp.expand_dims(jnp.expand_dims(freqs_split[2][:ppw], axis=0), axis=1) - freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1])) + freqs_w = jnp.expand_dims(jnp.expand_dims(self.freqs_split[2][:ppw], axis=0), axis=1) + freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, self.freqs_split[2].shape[-1])) freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1) freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1)) @@ -361,22 +362,41 @@ def __call__( shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( (self.scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 ) + print("Wan Block -- START -- ") # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) + print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states)) + print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states)) attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) + print("Wan Block -- Self Attn. attn_output, min: ", np.min(attn_output)) + print("Wan Block -- Self Attn. attn_output, max: ", np.max(attn_output)) hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) + print("Wan Block -- hidden_states, min: ", np.min(hidden_states)) + print("Wan Block -- hidden_states, max: ", np.max(hidden_states)) # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)) + print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states)) + print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states)) attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + print("Wan Block -- Cross Attn. attn_output, min: ", np.min(attn_output)) + print("Wan Block -- Cross Attn. attn_output, max: ", np.max(attn_output)) hidden_states = hidden_states + attn_output + print("Wan Block -- hidden_states, min: ", np.min(hidden_states)) + print("Wan Block -- hidden_states, max: ", np.max(hidden_states)) # 3. Feed-forward norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype) - + print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states)) + print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states)) ff_output = self.ffn(norm_hidden_states) + print("Wan Block -- ff_output, min: ", np.min(ff_output)) + print("Wan Block -- ff_output, max: ", np.max(ff_output)) hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(hidden_states.dtype) + print("Wan Block -- hidden_states, min: ", np.min(hidden_states)) + print("Wan Block -- hidden_states, max: ", np.max(hidden_states)) + print("Wan Block -- COMPLETE -- ") return hidden_states @@ -495,11 +515,22 @@ def __call__( rotary_emb = self.rope(hidden_states) hidden_states = self.patch_embedding(hidden_states) + print("***** After patch embedding") + print("hidden_states, min: ", np.min(hidden_states)) + print("hidden_states, max: ", np.max(hidden_states)) hidden_states = jax.lax.collapse(hidden_states, 1, -1) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) + print("***** After condition embedder") + print("temb, min: ", np.min(temb)) + print("temb, max: ", np.max(temb)) + print("timestep_proj, min: ", np.min(timestep_proj)) + print("timestep_proj, max: ", np.max(timestep_proj)) + print("encoder_hidden_states min: ", np.min(encoder_hidden_states)) + print("encoder_hidden_states max: ", np.max(encoder_hidden_states)) + timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) if encoder_hidden_states_image is not None: @@ -507,7 +538,9 @@ def __call__( for block in self.blocks: hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) - + print("After block, hidden_states min:", np.min(hidden_states)) + print("After block, hidden_states max:", np.max(hidden_states)) + #breakpoint() shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) From 38bea205272b18efee4dca6b34056a2fbf3685c2 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 22 May 2025 21:02:07 +0000 Subject: [PATCH 35/54] fix gelu block. --- src/maxdiffusion/configs/base_wan_14b.yml | 1 + src/maxdiffusion/configs/base_wan_t2v.yml | 269 ------------------ .../wan/transformers/transformer_wan.py | 133 ++++----- 3 files changed, 59 insertions(+), 344 deletions(-) delete mode 100644 src/maxdiffusion/configs/base_wan_t2v.yml diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index f668503f4..e905787b5 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -102,6 +102,7 @@ base_output_directory: "" # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False # Parallelism mesh_axes: ['data', 'fsdp', 'tensor'] diff --git a/src/maxdiffusion/configs/base_wan_t2v.yml b/src/maxdiffusion/configs/base_wan_t2v.yml deleted file mode 100644 index 28ef6e77e..000000000 --- a/src/maxdiffusion/configs/base_wan_t2v.yml +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This sentinel is a reminder to choose a real run name. -run_name: '' - -metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. -# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ -write_metrics: True -gcs_metrics: False -# If true save config to GCS in {base_output_directory}/{run_name}/ -save_config_to_gcs: False -log_period: 100 - -pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' - -# Flux params -flux_name: "flux-dev" -max_sequence_length: 512 -time_shift: True -base_shift: 0.5 -max_shift: 1.15 -# offloads t5 encoder after text encoding to save memory. -offload_encoders: True - - -unet_checkpoint: '' -revision: 'refs/pr/95' -# This will convert the weights to this dtype. -# When running inference on TPUv5e, use weights_dtype: 'bfloat16' -weights_dtype: 'bfloat16' -# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) -activations_dtype: 'bfloat16' - -# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision -# Options are "DEFAULT", "HIGH", "HIGHEST" -# fp32 activations and fp32 weights with HIGHEST will provide the best precision -# at the cost of time. -precision: "DEFAULT" - -# if False state is not jitted and instead replicate is called. This is good for debugging on single host -# It must be True for multi-host. -jit_initializers: True - -# Set true to load weights from pytorch -from_pt: True -split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te - -flash_block_sizes: {} -# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. -# flash_block_sizes: { -# "block_q" : 1536, -# "block_kv_compute" : 1536, -# "block_kv" : 1536, -# "block_q_dkv" : 1536, -# "block_kv_dkv" : 1536, -# "block_kv_dkv_compute" : 1536, -# "block_q_dq" : 1536, -# "block_kv_dq" : 1536 -# } -# GroupNorm groups -norm_num_groups: 32 - -# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch -# else they will be loaded from pretrained_model_name_or_path -train_new_unet: False - -# train text_encoder - Currently not supported for SDXL -train_text_encoder: False -text_encoder_learning_rate: 4.25e-6 - -# https://arxiv.org/pdf/2305.08891.pdf -snr_gamma: -1.0 - -timestep_bias: { - # a value of later will increase the frequence of the model's final training steps. - # none, earlier, later, range - strategy: "none", - # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. - multiplier: 1.0, - # when using strategy=range, the beginning (inclusive) timestep to bias. - begin: 0, - # when using strategy=range, the final step (inclusive) to bias. - end: 1000, - # portion of timesteps to bias. - # 0.5 will bias one half of the timesteps. Value of strategy determines - # whether the biased portions are in the earlier or later timesteps. - portion: 0.25 -} - -# Override parameters from checkpoints's scheduler. -diffusion_scheduler_config: { - _class_name: 'FlaxEulerDiscreteScheduler', - prediction_type: 'epsilon', - rescale_zero_terminal_snr: False, - timestep_spacing: 'trailing' -} - -# Output directory -# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" -base_output_directory: "" - -# Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' - -# Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] - -# batch : batch dimension of data and activations -# hidden : -# embed : attention qkv dense layer hidden dim named as embed -# heads : attention head dim = num_heads * head_dim -# length : attention sequence length -# temb_in : dense.shape[0] of resnet dense before conv -# out_c : dense.shape[1] of resnet dense before conv -# out_channels : conv.shape[-1] activation -# keep_1 : conv.shape[0] weight -# keep_2 : conv.shape[1] weight -# conv_in : conv.shape[2] weight -# conv_out : conv.shape[-1] weight -logical_axis_rules: [ - ['batch', 'data'], - ['activation_batch', ['data','fsdp']], - ['activation_heads', 'tensor'], - ['activation_kv', 'tensor'], - ['mlp','tensor'], - ['embed','fsdp'], - ['heads', 'tensor'], - ['conv_batch', ['data','fsdp']], - ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], - ] -data_sharding: [['data', 'fsdp', 'tensor']] - -# One axis for each parallelism type may hold a placeholder (-1) -# value to auto-shard based on available slices and devices. -# By default, product of the DCN axes should equal number of slices -# and product of the ICI axes should equal number of devices per slice. -dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded -dcn_fsdp_parallelism: -1 -dcn_tensor_parallelism: 1 -ici_data_parallelism: -1 -ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded -ici_tensor_parallelism: 1 - -# Dataset -# Replace with dataset path or train_data_dir. One has to be set. -dataset_name: 'diffusers/pokemon-gpt4-captions' -train_split: 'train' -dataset_type: 'tf' -cache_latents_text_encoder_outputs: True -# cache_latents_text_encoder_outputs only apply to dataset_type="tf", -# only apply to small dataset that fits in memory -# prepare image latents and text encoder outputs -# Reduce memory consumption and reduce step time during training -# transformed dataset is saved at dataset_save_location -dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' -train_data_dir: '' -dataset_config_name: '' -jax_cache_dir: '' -hf_data_dir: '' -hf_train_files: '' -hf_access_token: '' -image_column: 'image' -caption_column: 'text' -resolution: 1024 -center_crop: False -random_flip: False -# If cache_latents_text_encoder_outputs is True -# the num_proc is set to 1 -tokenize_captions_num_proc: 4 -transform_images_num_proc: 4 -reuse_example_batch: False -enable_data_shuffling: True - -# checkpoint every number of samples, -1 means don't checkpoint. -checkpoint_every: -1 -# enables one replica to read the ckpt then broadcast to the rest -enable_single_replica_ckpt_restoring: False - -# Training loop -learning_rate: 4.e-7 -scale_lr: False -max_train_samples: -1 -# max_train_steps takes priority over num_train_epochs. -max_train_steps: 200 -num_train_epochs: 1 -seed: 0 -output_dir: 'sdxl-model-finetuned' -per_device_batch_size: 1 - -warmup_steps_fraction: 0.0 -learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. - -# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before -# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. - -# AdamW optimizer parameters -adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. -adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. -adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. -adam_weight_decay: 1.e-2 # AdamW Weight decay -max_grad_norm: 1.0 - -enable_profiler: False -# Skip first n steps for profiling, to omit things like compilation and to give -# the iteration time a chance to stabilize. -skip_first_n_steps_for_profiler: 5 -profiler_steps: 10 - -# Generation parameters -prompt: "A magical castle in the middle of a forest, artistic drawing" -prompt_2: "A magical castle in the middle of a forest, artistic drawing" -negative_prompt: "purple, red" -do_classifier_free_guidance: True -guidance_scale: 3.5 -# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf -guidance_rescale: 0.0 -num_inference_steps: 50 - -# SDXL Lightning parameters -lightning_from_pt: True -# Empty or "ByteDance/SDXL-Lightning" to enable lightning. -lightning_repo: "" -# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. -lightning_ckpt: "" - -# LoRA parameters -# Values are lists to support multiple LoRA loading during inference in the future. -lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], - from_pt: [] -} -# Ex with values: -# lora_config : { -# lora_model_name_or_path: ["ByteDance/Hyper-SD"], -# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], -# adapter_name: ["hyper-sdxl"], -# scale: [0.7], -# from_pt: [True] -# } - -enable_mllog: False - -#controlnet -controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' -controlnet_from_pt: True -controlnet_conditioning_scale: 0.5 -controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' -quantization: '' -# Shard the range finding operation for quantization. By default this is set to number of slices. -quantization_local_shard_count: -1 -compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. - diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index da479d237..659501140 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -145,45 +145,6 @@ def __call__( return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image - -class WanTransformer3DModel(nnx.Module, FlaxModelMixin, ConfigMixin): - def __init__( - self, - rngs: nnx.Rngs, - patch_size: Tuple[int] = (1, 2, 2), - num_attention_heads: int = 40, - attention_head_dim: int = 128, - in_channels: int = 16, - out_channels: int = 16, - text_dim: int = 4096, - freq_dim: int = 256, - ffn_dim: int = 13824, - num_layers: int = 40, - cross_attn_norm: bool = True, - qk_norm: Optional[str] = "rms_norm_across_heads", - eps: float = 1e-6, - image_dim: Optional[int] = None, - added_kv_proj_dim: Optional[int] = None, - rope_max_seq_len: int = 1024, - pos_embed_seq_len: Optional[int] = None, - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - attention: str = "dot_product", - ): - inner_dim = num_attention_heads * attention_head_dim - out_channels = out_channels or in_channels - - #1. Patch & position embedding - self.rope = WanRotaryPosEmbed( - attention_head_dim, - patch_size, - rope_max_seq_len - ) - class ApproximateGELU(nnx.Module): r""" The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this @@ -213,7 +174,7 @@ def __init__( def __call__(self, x: jax.Array) -> jax.Array: x = self.proj(x) - return x * jax.nn.sigmoid(1.702 * x) + return nnx.gelu(x) class WanFeedForward(nnx.Module): @@ -362,41 +323,54 @@ def __call__( shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( (self.scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 ) - print("Wan Block -- START -- ") + # print("Wan Block -- START -- ") + # print("shift_msa min: ", shift_msa.min()) + # print("shift_msa max: ", shift_msa.max()) + # print("scale_msa min: ", scale_msa.min()) + # print("scale_msa max: ", scale_msa.max()) + # print("gate_msa min: ", gate_msa.min()) + # print("gate_msa max: ", gate_msa.max()) + # print("c_shift_msa min: ", c_shift_msa.min()) + # print("c_shift_msa max: ", c_shift_msa.max()) + # print("c_scale_msa min: ", c_scale_msa.min()) + # print("c_scale_msa max: ", c_scale_msa.max()) + # print("c_gate_msa min: ", c_gate_msa.min()) + # print("c_gate_msa max: ", c_gate_msa.max()) + # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) - print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states)) - print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states)) + # print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states)) + # print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states)) attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) - print("Wan Block -- Self Attn. attn_output, min: ", np.min(attn_output)) - print("Wan Block -- Self Attn. attn_output, max: ", np.max(attn_output)) + # print("Wan Block -- Self Attn. attn_output, min: ", np.min(attn_output)) + # print("Wan Block -- Self Attn. attn_output, max: ", np.max(attn_output)) hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) - print("Wan Block -- hidden_states, min: ", np.min(hidden_states)) - print("Wan Block -- hidden_states, max: ", np.max(hidden_states)) + # print("Wan Block -- hidden_states, min: ", np.min(hidden_states)) + # print("Wan Block -- hidden_states, max: ", np.max(hidden_states)) # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)) - print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states)) - print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states)) + # print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states)) + # print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states)) attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) - print("Wan Block -- Cross Attn. attn_output, min: ", np.min(attn_output)) - print("Wan Block -- Cross Attn. attn_output, max: ", np.max(attn_output)) + # print("Wan Block -- Cross Attn. attn_output, min: ", np.min(attn_output)) + # print("Wan Block -- Cross Attn. attn_output, max: ", np.max(attn_output)) hidden_states = hidden_states + attn_output - print("Wan Block -- hidden_states, min: ", np.min(hidden_states)) - print("Wan Block -- hidden_states, max: ", np.max(hidden_states)) + # print("Wan Block -- hidden_states, min: ", np.min(hidden_states)) + # print("Wan Block -- hidden_states, max: ", np.max(hidden_states)) # 3. Feed-forward norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype) - print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states)) - print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states)) + # print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states)) + # print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states)) ff_output = self.ffn(norm_hidden_states) - print("Wan Block -- ff_output, min: ", np.min(ff_output)) - print("Wan Block -- ff_output, max: ", np.max(ff_output)) + # print("Wan Block -- ff_output, min: ", np.min(ff_output)) + # print("Wan Block -- ff_output, max: ", np.max(ff_output)) hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(hidden_states.dtype) - print("Wan Block -- hidden_states, min: ", np.min(hidden_states)) - print("Wan Block -- hidden_states, max: ", np.max(hidden_states)) - print("Wan Block -- COMPLETE -- ") + # print("Wan Block -- hidden_states, min: ", np.min(hidden_states)) + # print("Wan Block -- hidden_states, max: ", np.max(hidden_states)) + # print("Wan Block -- COMPLETE -- ") return hidden_states @@ -431,7 +405,6 @@ def __init__( precision: jax.lax.Precision = None, attention: str = "dot_product", ): - inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels @@ -515,39 +488,49 @@ def __call__( rotary_emb = self.rope(hidden_states) hidden_states = self.patch_embedding(hidden_states) - print("***** After patch embedding") - print("hidden_states, min: ", np.min(hidden_states)) - print("hidden_states, max: ", np.max(hidden_states)) + # print("***** After patch embedding") + # print("hidden_states, min: ", np.min(hidden_states)) + # print("hidden_states, max: ", np.max(hidden_states)) hidden_states = jax.lax.collapse(hidden_states, 1, -1) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) - print("***** After condition embedder") - print("temb, min: ", np.min(temb)) - print("temb, max: ", np.max(temb)) - print("timestep_proj, min: ", np.min(timestep_proj)) - print("timestep_proj, max: ", np.max(timestep_proj)) - print("encoder_hidden_states min: ", np.min(encoder_hidden_states)) - print("encoder_hidden_states max: ", np.max(encoder_hidden_states)) + # print("***** After condition embedder") + # print("temb, min: ", np.min(temb)) + # print("temb, max: ", np.max(temb)) + # print("timestep_proj, min: ", np.min(timestep_proj)) + # print("timestep_proj, max: ", np.max(timestep_proj)) + # print("encoder_hidden_states min: ", np.min(encoder_hidden_states)) + # print("encoder_hidden_states max: ", np.max(encoder_hidden_states)) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) if encoder_hidden_states_image is not None: raise NotImplementedError("img2vid is not yet implemented.") - for block in self.blocks: hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) - print("After block, hidden_states min:", np.min(hidden_states)) - print("After block, hidden_states max:", np.max(hidden_states)) + # print("After block, hidden_states min:", np.min(hidden_states)) + # print("After block, hidden_states max:", np.max(hidden_states)) + # jax.debug.print("after block, hidden_states min: {x}", x=hidden_states.min()) + # jax.debug.print("after block, hidden_states min: {x}", x=hidden_states.max()) #breakpoint() shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) + # print("shift min: ", shift.min()) + # print("shift.max: ", shift.max()) + # print("scale.min: ", scale.min()) + # print("scale.max: ", scale.max()) hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) + # print("hidden_states.min: ", hidden_states.min()) + # print("hidden_states.max: ", hidden_states.max()) hidden_states = self.proj_out(hidden_states) + # print("After proj_out -- hidden_states.min: ", hidden_states.min()) + # print("After proj_out -- hidden_states.max: ", hidden_states.max()) # TODO - can this reshape happen in a single command? hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1) hidden_states = hidden_states.reshape(batch_size, num_frames, height, width, num_channels) - + # jax.debug.print("FINAL hidden_states min: {x}", x=hidden_states.min()) + # jax.debug.print("FINAL hidden_states max: {x}", x=hidden_states.max()) return hidden_states \ No newline at end of file From 716598bc7837e757c63e377905292382d2d2bd6a Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 23 May 2025 22:07:38 +0000 Subject: [PATCH 36/54] wip - building pipeline and gen code. --- src/maxdiffusion/configs/base_wan_14b.yml | 3 +- src/maxdiffusion/generate_wan.py | 28 +++ .../pipelines/wan/wan_pipeline.py | 224 ++++++++++++++++++ 3 files changed, 254 insertions(+), 1 deletion(-) create mode 100644 src/maxdiffusion/generate_wan.py create mode 100644 src/maxdiffusion/pipelines/wan/wan_pipeline.py diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index e905787b5..8edd8237d 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -217,8 +217,9 @@ do_classifier_free_guidance: True guidance_scale: 3.5 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 -num_inference_steps: 50 +num_inference_steps: 30 save_final_checkpoint: False +flow_shift: 5.0 # SDXL Lightning parameters lightning_from_pt: True diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py new file mode 100644 index 000000000..06aef79a3 --- /dev/null +++ b/src/maxdiffusion/generate_wan.py @@ -0,0 +1,28 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence +from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline +from maxdiffusion import pyconfig +from absl import app + +def run(config): + pipeline = WanPipeline.from_pretrained(config) + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + +if __name__ == "__main__": + app.run(main) \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py new file mode 100644 index 000000000..47d821ee6 --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -0,0 +1,224 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union, Optional +import numpy as np +import jax +from jax.sharding import Mesh, PositionalSharding +from flax import nnx +from ...pyconfig import HyperParameters +from ... import max_utils +from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae +from ...models.wan.transformers.transformer_wan import WanModel +from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache +from maxdiffusion.video_processor import VideoProcessor +from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState +from transformers import AutoTokenizer, UMT5EncoderModel +import ftfy +import html +import re +import torch + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + +class WanPipeline: + r""" + Pipeline for text-to-video generation using Wan. + + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`WanModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`FlaxUniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: WanModel, + vae: AutoencoderKLWan, + vae_cache: AutoencoderKLWanCache, + scheduler: FlaxUniPCMultistepScheduler, + scheduler_state: UniPCMultistepSchedulerState, + devices_array: np.array, + mesh: Mesh, + config: HyperParameters + ): + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.transformer = transformer + self.vae = vae + self.vae_cache = vae_cache + self.scheduler = scheduler + self.scheduler_state = scheduler_state + self.devices_array = devices_array + self.mesh = mesh + self.config = config + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + @classmethod + def load_vae(cls, rngs: nnx.Rngs, config: HyperParameters): + wan_vae = AutoencoderKLWan.from_config( + config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs + ) + vae_cache = AutoencoderKLWanCache(wan_vae) + + graphdef, state = nnx.split(wan_vae, nnx.Param) + params = state.to_pure_dict() + # This replaces random params with the model. + params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") + params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) + wan_vae = nnx.merge(graphdef, params) + + return wan_vae, vae_cache + + @classmethod + def load_text_encoder(cls, config: HyperParameters): + text_encoder = UMT5EncoderModel.from_pretrained( + config.pretrained_model_name_or_path, + subfolder="text_encoder", + ) + return text_encoder + + @classmethod + def load_tokenizer(cls, config: HyperParameters): + tokenizer = AutoTokenizer.from_pretrained( + config.pretrained_model_name_or_path, + subfolder="tokenizer", + ) + return tokenizer + + @classmethod + def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): + wan_transformer = WanModel.from_config( + config.pretrained_model_name_or_path, + subfolder="transformer", + rngs=rngs, + attention=config.attention, + mesh=mesh, + dtype=config.activations_dtype, + weights_dtype=config.weights_dtype + ) + graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) + params = state.to_pure_dict() + del state + params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu") + params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) + params = jax.device_put(params, PositionalSharding(devices_array).replicate()) + wan_transformer = nnx.merge(graphdef, params, rest_of_state) + return wan_transformer + + @classmethod + def load_scheduler(cls, config): + scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( + config.pretrained_model_name_or_path, + subfolder="scheduler", + flow_shift=config.flow_shift # 5.0 for 720p, 3.0 for 480p + ) + return scheduler, scheduler_state + + @classmethod + def from_pretrained(cls, config: HyperParameters): + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + rng = jax.random.key(config.seed) + rngs = nnx.Rngs(rng) + + wan_vae, vae_cache = cls.load_vae(rngs=rngs, config=config) + text_encoder = cls.load_text_encoder(config=config) + tokenizer = cls.load_tokenizer(config=config) + transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + scheduler, scheduler_state = cls.load_scheduler(config=config) + + return WanPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + vae=wan_vae, + vae_cache=vae_cache, + scheduler=scheduler, + scheduler_state=scheduler_state, + devices_array=devices_array, + mesh=mesh, + config=config + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + ): + + + From 6973222aaa9cf9377f792ab078efa8e31de4af01 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 27 May 2025 15:45:15 +0000 Subject: [PATCH 37/54] initial wan pipeline for txt2vid. Not currently working. --- src/maxdiffusion/configs/base_wan_14b.yml | 11 +- src/maxdiffusion/generate_wan.py | 10 ++ .../pipelines/wan/wan_pipeline.py | 153 ++++++++++++++++++ 3 files changed, 170 insertions(+), 4 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 8edd8237d..c9b5a20cf 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -210,11 +210,14 @@ skip_first_n_steps_for_profiler: 5 profiler_steps: 10 # Generation parameters -prompt: "A magical castle in the middle of a forest, artistic drawing" -prompt_2: "A magical castle in the middle of a forest, artistic drawing" -negative_prompt: "purple, red" +prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True -guidance_scale: 3.5 +height: 720 +width: 1280 +num_frames: 81 +guidance_scale: 5.0 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 06aef79a3..3ee5d6c3f 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -20,6 +20,16 @@ def run(config): pipeline = WanPipeline.from_pretrained(config) + pipeline( + prompt=config.prompt, + negative_prompt=config.negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + ) + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) run(pyconfig.config) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 47d821ee6..e56c3bae1 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -13,11 +13,14 @@ # limitations under the License. from typing import List, Union, Optional +from functools import partial import numpy as np import jax +import jax.numpy as jnp from jax.sharding import Mesh, PositionalSharding from flax import nnx from ...pyconfig import HyperParameters +from ... import max_logging from ... import max_utils from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae from ...models.wan.transformers.transformer_wan import WanModel @@ -219,6 +222,156 @@ def encode_prompt( num_videos_per_prompt: int = 1, max_sequence_length: int = 226, ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + ) + + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + ) + + prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype) + negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype) + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + batch_size: int, + vae_scale_factor_temporal: int, + vae_scale_factor_spatial: int, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_channels_latents: int = 16, + ): + rng = jax.random.key(self.config.seed) + num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // vae_scale_factor_spatial, + int(width) // vae_scale_factor_spatial, + num_channels_latents + ) + latents = jax.random.normal(rng, shape=shape, dtype=self.config.weights_dtype) + + return latents + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512 + ): + if num_frames % self.vae_scale_factor_temporal != 1: + max_logging.log( + f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + max_sequence_length=max_sequence_length + ) + + num_channel_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size=batch_size, + vae_scale_factor_temporal=self.vae_scale_factor_temporal, + vae_scale_factor_spatial=self.vae_scale_factor_spatial, + height=height, + width=width, + num_frames=num_frames, + num_channels_latents=num_channel_latents + ) + prompt_embeds = jnp.concatenate([prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype) + negative_prompt_embeds = jnp.concatenate([negative_prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype) + latents = jax.device_put(latents, PositionalSharding(self.devices_array).replicate()) + prompt_embeds = jax.device_put(prompt_embeds, PositionalSharding(self.devices_array).replicate()) + negative_prompt_embeds = jax.device_put(negative_prompt_embeds, PositionalSharding(self.devices_array).replicate()) + + scheduler_state = self.scheduler.set_timesteps( + self.scheduler_state, num_inference_steps=self.config.num_inference_steps, shape=latents.shape + ) + + graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) + + p_run_inference = partial( + run_inference, + guidance_scale=self.config.guidance_scale, + num_inference_steps=self.config.num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state + ) + with self.mesh: + latent = p_run_inference( + graphdef=graphdef, + sharded_state=state, + rest_of_state=rest_of_state, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds + ) + + +@partial(jax.jit, static_argnums=(6, 7, 8)) +def run_inference( + graphdef, + sharded_state, + rest_of_state, + latents: jnp.array, + prompt_embeds: jnp.array, + negative_prompt_embeds: jnp.array, + guidance_scale: float, + num_inference_steps: int, + scheduler : FlaxUniPCMultistepScheduler, + scheduler_state): + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + do_classifier_free_guidance = guidance_scale > 1.0 + + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents.shape[0]) + + noise_pred = wan_transformer( + hidden_states=latents, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + return_dict=False + )[0] + + if do_classifier_free_guidance: + noise_uncond = wan_transformer( + hidden_states=latents, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + return_dict=False + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + + return latents \ No newline at end of file From 0731a49a123376f24f0cd61fb4061a157274a25e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 29 May 2025 01:11:38 +0000 Subject: [PATCH 38/54] add sharding annotations for vae. Verified transformer correctness for one step. --- src/maxdiffusion/configs/base_wan_14b.yml | 1 + src/maxdiffusion/models/attention_flax.py | 51 +-- src/maxdiffusion/models/embeddings_flax.py | 2 +- .../models/wan/autoencoder_kl_wan.py | 364 ++++++++++++++++-- .../wan/transformers/transformer_wan.py | 56 --- .../pipelines/wan/wan_pipeline.py | 175 ++++++--- src/maxdiffusion/tests/wan_vae_test.py | 89 ++++- 7 files changed, 552 insertions(+), 186 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index c9b5a20cf..1f5920a7b 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -130,6 +130,7 @@ logical_axis_rules: [ ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], + ['conv_in', 'fsdp'] ] data_sharding: [['data', 'fsdp', 'tensor']] diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 4a62be819..f87627a87 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -101,7 +101,8 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size): """ Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. """ - tensor = _unflatten_heads(tensor, heads) + if tensor.ndim != 4: + tensor = _unflatten_heads(tensor, heads) # pad head_dim to 128 if less than that. kv_size = tensor.shape[-1] @@ -319,12 +320,14 @@ def _apply_attention( ): """Routes to different attention kernels.""" _check_attention_inputs(query, key, value) - + seq_len_idx = 1 + if query.ndim == 4: + seq_len_idx = 2 if attention_kernel == "flash": can_use_flash_attention = ( - query.shape[1] >= flash_min_seq_length - and key.shape[1] >= flash_min_seq_length - and value.shape[1] >= flash_min_seq_length + query.shape[seq_len_idx] >= flash_min_seq_length + and key.shape[seq_len_idx] >= flash_min_seq_length + and value.shape[seq_len_idx] >= flash_min_seq_length ) else: can_use_flash_attention = True @@ -584,7 +587,6 @@ def __init__( if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None: raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") - self.dim_head = dim_head self.heads = heads self.inner_dim = dim_head * heads @@ -681,7 +683,6 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1]) xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) - freqs_cis = freqs_cis[None, None, ...] xq_out_complex = xq_ * freqs_cis xk_out_complex = xk_ * freqs_cis @@ -696,58 +697,26 @@ def __call__( encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None ) -> jax.Array: - print(" -- -- WanAttention -- ") dtype = hidden_states.dtype if encoder_hidden_states is None: encoder_hidden_states = hidden_states query_proj = self.query(hidden_states) - print("query_proj min: ", np.min(query_proj)) - print("query_proj max: ", np.max(query_proj)) key_proj = self.key(encoder_hidden_states) - print("key_proj min: ", np.min(key_proj)) - print("key_proj max: ", np.max(key_proj)) value_proj = self.value(encoder_hidden_states) - print("value_proj min: ", np.min(value_proj)) - print("value_proj max: ", np.max(value_proj)) - - query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names) - key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) - value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) if self.qk_norm: query_proj = self.norm_q(query_proj) key_proj = self.norm_k(key_proj) - print("query_proj min: ", np.min(query_proj)) - print("query_proj max: ", np.max(query_proj)) - print("key_proj min: ", np.min(key_proj)) - print("key_proj max: ", np.max(key_proj)) - if rotary_emb is not None: query_proj = _unflatten_heads(query_proj, self.heads) key_proj = _unflatten_heads(key_proj, self.heads) - # value_proj = _unflatten_heads(value_proj, self.heads) + value_proj = _unflatten_heads(value_proj, self.heads) query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) - print("Rope query_proj min: ", np.min(query_proj)) - print("Rope query_proj max: ", np.max(query_proj)) - print("Rope key_proj min: ", np.min(key_proj)) - print("Rope key_proj max: ", np.max(key_proj)) - #breakpoint() - query_proj = _reshape_heads_to_head_dim(query_proj) - key_proj = _reshape_heads_to_head_dim(key_proj) attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) - try: - print("attn_output min: ", np.min(attn_output)) - print("attn_output_for_print max: ", np.max(attn_output)) - except: - pass attn_output = attn_output.astype(dtype=dtype) - hidden_states = self.proj_attn(hidden_states) - print("hidden_states min: ", np.min(hidden_states)) - print("hidden_states max: ", np.max(hidden_states)) - print(" -- -- WanAttention DONE -- ") - #breakpoint() + hidden_states = self.proj_attn(attn_output) return hidden_states diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index ad2283034..ef57aaf63 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -227,7 +227,7 @@ def get_1d_rotary_pos_embed( out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1) else: # Wan 2.1 - out = jax.lax.complex(jnp.cos(freqs), jnp.sin(freqs)) + out = jnp.exp(1j * freqs) return out class NNXPixArtAlphaTextProjection(nnx.Module): diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 9c92e2ee2..c80f88b01 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -51,6 +51,10 @@ def __init__( stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, use_bias: bool = True, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") self.stride = _canonicalize_tuple(stride, 3, "stride") @@ -67,6 +71,12 @@ def __init__( # Store the amount of padding needed *before* the depth dimension for caching logic self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] + # Set sharding dynamically based on out_channels. + num_fsdp_axis_devices = mesh.device_ids.shape[1] + kernel_sharding = (None, None, None, None, None) + if out_channels % num_fsdp_axis_devices == 0: + kernel_sharding = (None, None, None, None, "conv_out") + self.conv = nnx.Conv( in_features=in_channels, out_features=out_channels, @@ -75,6 +85,12 @@ def __init__( use_bias=use_bias, padding="VALID", # Handle padding manually rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), + kernel_sharding + ), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision ) def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array: @@ -175,8 +191,24 @@ def __init__( rngs: nnx.Rngs, kernel_size: Union[int, Tuple[int, int, int]], stride: Union[int, Tuple[int, int, int]] = 1, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): - self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs) + self.conv = nnx.Conv( + dim, dim, + kernel_size=kernel_size, + strides=stride, + use_bias=True, + rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), + (None, None, None, None) + ), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision + ) def __call__(self, x): return self.conv(x) @@ -189,6 +221,10 @@ def __init__( dim: int, mode: str, rngs: nnx.Rngs, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.dim = dim self.mode = mode @@ -204,6 +240,12 @@ def __init__( padding="SAME", use_bias=True, rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), + (None, None, None, "conv_out") + ), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision ), ) elif mode == "upsample3d": @@ -216,6 +258,12 @@ def __init__( padding="SAME", use_bias=True, rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), + (None, None, None, "conv_out") + ), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision ), ) self.time_conv = WanCausalConv3d( @@ -224,13 +272,44 @@ def __init__( out_channels=dim * 2, kernel_size=(3, 1, 1), padding=(1, 0, 0), + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision ) elif mode == "downsample2d": - self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2)) + self.resample = ZeroPaddedConv2D( + dim=dim, + rngs=rngs, + kernel_size=(3, 3), + stride=(2, 2), + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) elif mode == "downsample3d": - self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2)) + self.resample = ZeroPaddedConv2D( + dim=dim, + rngs=rngs, + kernel_size=(3, 3), + stride=(2, 2), + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) self.time_conv = WanCausalConv3d( - rngs=rngs, in_channels=dim, out_channels=dim, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + rngs=rngs, + in_channels=dim, + out_channels=dim, + kernel_size=(3, 1, 1), + stride=(2, 1, 1), + padding=(0, 0, 0), + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision ) else: self.resample = Identity() @@ -292,16 +371,48 @@ def __init__( rngs: nnx.Rngs, dropout: float = 0.0, non_linearity: str = "silu", + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.nonlinearity = get_activation(non_linearity) # layers self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False) - self.conv1 = WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=3, padding=1) + self.conv1 = WanCausalConv3d( + rngs=rngs, + in_channels=in_dim, + out_channels=out_dim, + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False) - self.conv2 = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=out_dim, kernel_size=3, padding=1) + self.conv2 = WanCausalConv3d(rngs=rngs, + in_channels=out_dim, + out_channels=out_dim, + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) self.conv_shortcut = ( - WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=1) + WanCausalConv3d( + rngs=rngs, + in_channels=in_dim, + out_channels=out_dim, + kernel_size=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) if in_dim != out_dim else Identity() ) @@ -344,11 +455,35 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): class WanAttentionBlock(nnx.Module): - def __init__(self, dim: int, rngs: nnx.Rngs): + def __init__( + self, + dim: int, + rngs: nnx.Rngs, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): self.dim = dim self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False) - self.to_qkv = nnx.Conv(in_features=dim, out_features=dim * 3, kernel_size=(1, 1), rngs=rngs) - self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=(1, 1), rngs=rngs) + self.to_qkv = nnx.Conv( + in_features=dim, out_features=dim * 3, kernel_size=(1, 1), rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), + (None, None, None, "conv_out") + ), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision + ) + self.proj = nnx.Conv( + in_features=dim, out_features=dim, kernel_size=(1, 1), rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), + (None, None, "conv_in", None) + ), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision + ) def __call__(self, x: jax.Array): @@ -362,7 +497,6 @@ def __call__(self, x: jax.Array): # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3) qkv = jnp.transpose(qkv, (0, 1, 3, 2)) - # q, k, v = jnp.split(qkv, 3, axis=-1) q, k, v = jnp.split(qkv, 3, axis=-2) q = jnp.transpose(q, (0, 1, 3, 2)) k = jnp.transpose(k, (0, 1, 3, 2)) @@ -380,13 +514,56 @@ def __call__(self, x: jax.Array): class WanMidBlock(nnx.Module): - def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + def __init__( + self, dim: int, + rngs: nnx.Rngs, + dropout: float = 0.0, + non_linearity: str = "silu", + num_layers: int = 1, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): self.dim = dim - resnets = [WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity)] + resnets = [ + WanResidualBlock( + in_dim=dim, + out_dim=dim, + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) + ] attentions = [] for _ in range(num_layers): - attentions.append(WanAttentionBlock(dim=dim, rngs=rngs)) - resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity)) + attentions.append( + WanAttentionBlock( + dim=dim, + rngs=rngs, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) + ) + resnets.append( + WanResidualBlock( + in_dim=dim, + out_dim=dim, + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) + ) self.attentions = attentions self.resnets = resnets @@ -410,6 +587,10 @@ def __init__( dropout: float = 0.0, upsample_mode: Optional[str] = None, non_linearity: str = "silu", + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): # Create layers list resnets = [] @@ -417,7 +598,17 @@ def __init__( current_dim = in_dim for _ in range(num_res_blocks + 1): resnets.append( - WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs) + WanResidualBlock( + in_dim=current_dim, + out_dim=out_dim, + dropout=dropout, + non_linearity=non_linearity, + rngs=rngs, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) ) current_dim = out_dim self.resnets = resnets @@ -425,7 +616,17 @@ def __init__( # Add upsampling layer if needed. self.upsamplers = None if upsample_mode is not None: - self.upsamplers = [WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs)] + self.upsamplers = [ + WanResample( + dim=out_dim, + mode=upsample_mode, + rngs=rngs, + mesh=mesh, + weights_dtype=weights_dtype, + dtype=dtype, + precision=precision + ) + ] def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): for resnet in self.resnets: @@ -455,6 +656,10 @@ def __init__( temperal_downsample=[True, True, False], dropout=0.0, non_linearity: str = "silu", + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.dim = dim self.z_dim = z_dim @@ -475,6 +680,10 @@ def __init__( out_channels=dims[0], kernel_size=3, padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision ) # downsample blocks @@ -482,15 +691,44 @@ def __init__( for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): # residual (+attention) blocks for _ in range(num_res_blocks): - self.down_blocks.append(WanResidualBlock(in_dim=in_dim, out_dim=out_dim, dropout=dropout, rngs=rngs)) + self.down_blocks.append( + WanResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout,rngs=rngs, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) + ) if scale in attn_scales: - self.down_blocks.append(WanAttentionBlock(dim=out_dim, rngs=rngs)) + self.down_blocks.append( + WanAttentionBlock( + dim=out_dim, + rngs=rngs, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) + ) in_dim = out_dim # downsample block if i != len(dim_mult) - 1: mode = "downsample3d" if temperal_downsample[i] else "downsample2d" - self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs)) + self.down_blocks.append( + WanResample( + out_dim, + mode=mode, + rngs=rngs, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) + ) scale /= 2.0 # middle_blocks @@ -500,11 +738,25 @@ def __init__( dropout=dropout, non_linearity=non_linearity, num_layers=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision ) # output blocks self.norm_out = WanRMS_norm(out_dim, channel_first=False, images=False, rngs=rngs) - self.conv_out = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=z_dim, kernel_size=3, padding=1) + self.conv_out = WanCausalConv3d( + rngs=rngs, + in_channels=out_dim, + out_channels=z_dim, + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: @@ -567,6 +819,10 @@ def __init__( temperal_upsample=[False, True, True], dropout=0.0, non_linearity: str = "silu", + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.dim = dim self.z_dim = z_dim @@ -582,10 +838,29 @@ def __init__( scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block - self.conv_in = WanCausalConv3d(rngs=rngs, in_channels=z_dim, out_channels=dims[0], kernel_size=3, padding=1) + self.conv_in = WanCausalConv3d(rngs=rngs, + in_channels=z_dim, + out_channels=dims[0], + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) # middle_blocks - self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1) + self.mid_block = WanMidBlock( + dim=dims[0], + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + num_layers=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) # upsample blocks self.up_blocks = [] @@ -607,6 +882,10 @@ def __init__( upsample_mode=upsample_mode, non_linearity=non_linearity, rngs=rngs, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision ) self.up_blocks.append(up_block) @@ -616,7 +895,16 @@ def __init__( # output blocks self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False) - self.conv_out = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=3, kernel_size=3, padding=1) + self.conv_out = WanCausalConv3d(rngs=rngs, + in_channels=out_dim, + out_channels=3, + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: @@ -728,6 +1016,10 @@ def __init__( 2.8251, 1.9160, ], + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.z_dim = z_dim self.temperal_downsample = temperal_downsample @@ -744,13 +1036,31 @@ def __init__( attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision + ) + self.quant_conv = WanCausalConv3d( + rngs=rngs, + in_channels=z_dim * 2, + out_channels=z_dim * 2, + kernel_size=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision ) - self.quant_conv = WanCausalConv3d(rngs=rngs, in_channels=z_dim * 2, out_channels=z_dim * 2, kernel_size=1) + self.post_quant_conv = WanCausalConv3d( rngs=rngs, in_channels=z_dim, out_channels=z_dim, kernel_size=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision ) self.decoder = WanDecoder3d( @@ -762,6 +1072,10 @@ def __init__( attn_scales=attn_scales, temperal_upsample=self.temporal_upsample, dropout=dropout, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision ) def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 659501140..8727242d5 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -323,54 +323,22 @@ def __call__( shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( (self.scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 ) - # print("Wan Block -- START -- ") - # print("shift_msa min: ", shift_msa.min()) - # print("shift_msa max: ", shift_msa.max()) - # print("scale_msa min: ", scale_msa.min()) - # print("scale_msa max: ", scale_msa.max()) - # print("gate_msa min: ", gate_msa.min()) - # print("gate_msa max: ", gate_msa.max()) - # print("c_shift_msa min: ", c_shift_msa.min()) - # print("c_shift_msa max: ", c_shift_msa.max()) - # print("c_scale_msa min: ", c_scale_msa.min()) - # print("c_scale_msa max: ", c_scale_msa.max()) - # print("c_gate_msa min: ", c_gate_msa.min()) - # print("c_gate_msa max: ", c_gate_msa.max()) # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) - # print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states)) - # print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states)) attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) - # print("Wan Block -- Self Attn. attn_output, min: ", np.min(attn_output)) - # print("Wan Block -- Self Attn. attn_output, max: ", np.max(attn_output)) hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) - # print("Wan Block -- hidden_states, min: ", np.min(hidden_states)) - # print("Wan Block -- hidden_states, max: ", np.max(hidden_states)) # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)) - # print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states)) - # print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states)) attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) - # print("Wan Block -- Cross Attn. attn_output, min: ", np.min(attn_output)) - # print("Wan Block -- Cross Attn. attn_output, max: ", np.max(attn_output)) hidden_states = hidden_states + attn_output - # print("Wan Block -- hidden_states, min: ", np.min(hidden_states)) - # print("Wan Block -- hidden_states, max: ", np.max(hidden_states)) # 3. Feed-forward norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype) - # print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states)) - # print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states)) ff_output = self.ffn(norm_hidden_states) - # print("Wan Block -- ff_output, min: ", np.min(ff_output)) - # print("Wan Block -- ff_output, max: ", np.max(ff_output)) hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(hidden_states.dtype) - # print("Wan Block -- hidden_states, min: ", np.min(hidden_states)) - # print("Wan Block -- hidden_states, max: ", np.max(hidden_states)) - # print("Wan Block -- COMPLETE -- ") return hidden_states @@ -488,45 +456,21 @@ def __call__( rotary_emb = self.rope(hidden_states) hidden_states = self.patch_embedding(hidden_states) - # print("***** After patch embedding") - # print("hidden_states, min: ", np.min(hidden_states)) - # print("hidden_states, max: ", np.max(hidden_states)) hidden_states = jax.lax.collapse(hidden_states, 1, -1) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) - # print("***** After condition embedder") - # print("temb, min: ", np.min(temb)) - # print("temb, max: ", np.max(temb)) - # print("timestep_proj, min: ", np.min(timestep_proj)) - # print("timestep_proj, max: ", np.max(timestep_proj)) - # print("encoder_hidden_states min: ", np.min(encoder_hidden_states)) - # print("encoder_hidden_states max: ", np.max(encoder_hidden_states)) - timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) if encoder_hidden_states_image is not None: raise NotImplementedError("img2vid is not yet implemented.") for block in self.blocks: hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) - # print("After block, hidden_states min:", np.min(hidden_states)) - # print("After block, hidden_states max:", np.max(hidden_states)) - # jax.debug.print("after block, hidden_states min: {x}", x=hidden_states.min()) - # jax.debug.print("after block, hidden_states min: {x}", x=hidden_states.max()) - #breakpoint() shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) - # print("shift min: ", shift.min()) - # print("shift.max: ", shift.max()) - # print("scale.min: ", scale.min()) - # print("scale.max: ", scale.max()) hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) - # print("hidden_states.min: ", hidden_states.min()) - # print("hidden_states.max: ", hidden_states.max()) hidden_states = self.proj_out(hidden_states) - # print("After proj_out -- hidden_states.min: ", hidden_states.min()) - # print("After proj_out -- hidden_states.max: ", hidden_states.max()) # TODO - can this reshape happen in a single command? hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index e56c3bae1..1f97ef478 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -26,6 +26,7 @@ from ...models.wan.transformers.transformer_wan import WanModel from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache from maxdiffusion.video_processor import VideoProcessor +from ...utils import export_to_video from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState from transformers import AutoTokenizer, UMT5EncoderModel import ftfy @@ -34,20 +35,35 @@ import torch def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text def prompt_clean(text): - text = whitespace_clean(basic_clean(text)) - return text + text = whitespace_clean(basic_clean(text)) + return text + + +def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState: + vs.sharding_rules = logical_axis_rules + return vs + +partial(nnx.jit, static_argnums=(1,)) +def create_sharded_logical_model(model, logical_axis_rules): + graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...) + p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=logical_axis_rules) + state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) + pspecs = nnx.get_partition_spec(state) + sharded_state = jax.lax.with_sharding_constraint(state, pspecs) + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + return wan_transformer class WanPipeline: r""" @@ -94,22 +110,6 @@ def __init__( self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - @classmethod - def load_vae(cls, rngs: nnx.Rngs, config: HyperParameters): - wan_vae = AutoencoderKLWan.from_config( - config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs - ) - vae_cache = AutoencoderKLWanCache(wan_vae) - - graphdef, state = nnx.split(wan_vae, nnx.Param) - params = state.to_pure_dict() - # This replaces random params with the model. - params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") - params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) - wan_vae = nnx.merge(graphdef, params) - - return wan_vae, vae_cache - @classmethod def load_text_encoder(cls, config: HyperParameters): text_encoder = UMT5EncoderModel.from_pretrained( @@ -126,6 +126,31 @@ def load_tokenizer(cls, config: HyperParameters): ) return tokenizer + @classmethod + def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): + wan_vae = AutoencoderKLWan.from_config( + config.pretrained_model_name_or_path, + subfolder="vae", + rngs=rngs, + mesh=mesh, + dtype=config.activations_dtype, + weights_dtype=config.weights_dtype + ) + vae_cache = AutoencoderKLWanCache(wan_vae) + + graphdef, state = nnx.split(wan_vae, nnx.Param) + params = state.to_pure_dict() + # This replaces random params with the model. + params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") + params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) + params = jax.device_put(params, PositionalSharding(devices_array).replicate()) + wan_vae = nnx.merge(graphdef, params) + p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules) + # Shard + with mesh: + wan_vae = p_create_sharded_logical_model(model=wan_vae) + return wan_vae, vae_cache + @classmethod def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): wan_transformer = WanModel.from_config( @@ -144,6 +169,10 @@ def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, c params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) params = jax.device_put(params, PositionalSharding(devices_array).replicate()) wan_transformer = nnx.merge(graphdef, params, rest_of_state) + # Shard + p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules) + with mesh: + wan_transformer = p_create_sharded_logical_model(model=wan_transformer) return wan_transformer @classmethod @@ -154,18 +183,21 @@ def load_scheduler(cls, config): flow_shift=config.flow_shift # 5.0 for 720p, 3.0 for 480p ) return scheduler, scheduler_state - + @classmethod def from_pretrained(cls, config: HyperParameters): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - - wan_vae, vae_cache = cls.load_vae(rngs=rngs, config=config) + + with mesh: + wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + text_encoder = cls.load_text_encoder(config=config) tokenizer = cls.load_tokenizer(config=config) - transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + scheduler, scheduler_state = cls.load_scheduler(config=config) return WanPipeline( @@ -221,25 +253,29 @@ def encode_prompt( negative_prompt: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: int = 1, max_sequence_length: int = 226, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - ) - - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_embeds = self._get_t5_prompt_embeds( - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - ) + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype) + + if negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype) - prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype) - negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype) return prompt_embeds, negative_prompt_embeds def prepare_latents( @@ -275,7 +311,10 @@ def __call__( num_inference_steps: int = 50, guidance_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512 + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None ): if num_frames % self.vae_scale_factor_temporal != 1: max_logging.log( @@ -293,19 +332,22 @@ def __call__( prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, - max_sequence_length=max_sequence_length + max_sequence_length=max_sequence_length, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds ) num_channel_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - batch_size=batch_size, - vae_scale_factor_temporal=self.vae_scale_factor_temporal, - vae_scale_factor_spatial=self.vae_scale_factor_spatial, - height=height, - width=width, - num_frames=num_frames, - num_channels_latents=num_channel_latents - ) + if latents is None: + latents = self.prepare_latents( + batch_size=batch_size, + vae_scale_factor_temporal=self.vae_scale_factor_temporal, + vae_scale_factor_spatial=self.vae_scale_factor_spatial, + height=height, + width=width, + num_frames=num_frames, + num_channels_latents=num_channel_latents + ) prompt_embeds = jnp.concatenate([prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype) negative_prompt_embeds = jnp.concatenate([negative_prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype) @@ -328,7 +370,7 @@ def __call__( scheduler_state=scheduler_state ) with self.mesh: - latent = p_run_inference( + latents = p_run_inference( graphdef=graphdef, sharded_state=state, rest_of_state=rest_of_state, @@ -336,9 +378,28 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds ) + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim) + latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim) + latents = latents / latents_std + latents_mean + + latents = latents.astype(self.config.weights_dtype) + + jitted_decode = jax.jit( + partial( + self.vae.decode, + feat_cache=self.vae_cache, + return_dict=False + ) + ) + with self.mesh: + video = jitted_decode(latents)[0] + video = jnp.transpose(video, (0, 4, 1, 2, 3)) + video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) + video = self.video_processor.postprocess_video(video, output_type="np") + export_to_video(video[0], "jax_output.mp4", fps=24) -@partial(jax.jit, static_argnums=(6, 7, 8)) +#@partial(jax.jit, static_argnums=(6, 7, 8)) def run_inference( graphdef, sharded_state, @@ -352,7 +413,6 @@ def run_inference( scheduler_state): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) do_classifier_free_guidance = guidance_scale > 1.0 - for step in range(num_inference_steps): t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] timestep = jnp.broadcast_to(t, latents.shape[0]) @@ -373,5 +433,4 @@ def run_inference( )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 7d750c8bb..fe037c255 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -14,6 +14,7 @@ limitations under the License. """ +import os import functools import torch import torch.nn as nn @@ -21,6 +22,11 @@ import jax import jax.numpy as jnp from flax import nnx +from jax.sharding import Mesh +from .. import pyconfig +from ..max_utils import ( + create_device_mesh, +) import numpy as np import unittest from absl.testing import absltest @@ -41,6 +47,8 @@ from ..utils import load_video from ..video_processor import VideoProcessor +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + CACHE_T = 2 @@ -249,6 +257,17 @@ def test_wan_resample(self): def test_3d_conv(self): key = jax.random.key(0) rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + batch_size = 1 in_depth, in_height, in_width = 10, 32, 32 in_channels = 3 @@ -269,7 +288,8 @@ def test_3d_conv(self): out_channels=out_channels, kernel_size=(kernel_d, kernel_h, kernel_w), padding=(padding_d, padding_h, padding_w), - rngs=rngs, # Pass rngs for initialization + rngs=rngs, # Pass rngs for initialization, + mesh=mesh ) # --- Test Case 1: No Cache --- @@ -289,6 +309,16 @@ def test_3d_conv(self): def test_wan_residual(self): key = jax.random.key(0) rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) # --- Test Case 1: same in/out dim --- in_dim = out_dim = 96 batch = 1 @@ -303,6 +333,7 @@ def test_wan_residual(self): in_dim=in_dim, out_dim=out_dim, rngs=rngs, + mesh=mesh ) dummy_input = jnp.ones(input_shape) dummy_output = wan_residual_block(dummy_input) @@ -317,6 +348,7 @@ def test_wan_residual(self): in_dim=in_dim, out_dim=out_dim, rngs=rngs, + mesh=mesh ) dummy_input = jnp.ones(input_shape) dummy_output = wan_residual_block(dummy_input) @@ -339,13 +371,23 @@ def test_wan_attention(self): def test_wan_midblock(self): key = jax.random.key(0) rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) batch = 1 t = 1 dim = 384 height = 60 width = 90 input_shape = (batch, t, height, width, dim) - wan_midblock = WanMidBlock(dim=dim, rngs=rngs) + wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh) dummy_input = jnp.ones(input_shape) output = wan_midblock(dummy_input) assert output.shape == input_shape @@ -353,6 +395,16 @@ def test_wan_midblock(self): def test_wan_decode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) dim = 96 z_dim = 16 dim_mult = [1, 2, 4, 4] @@ -367,6 +419,7 @@ def test_wan_decode(self): num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, + mesh=mesh ) vae_cache = AutoencoderKLWanCache(wan_vae) batch = 1 @@ -386,6 +439,16 @@ def test_wan_decode(self): def test_wan_encode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) dim = 96 z_dim = 16 dim_mult = [1, 2, 4, 4] @@ -400,6 +463,7 @@ def test_wan_encode(self): num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, + mesh=mesh ) vae_cache = AutoencoderKLWanCache(wan_vae) batch = 1 @@ -418,10 +482,25 @@ def vae_encode(video, wan_vae, vae_cache, key): latent = latent.latent_dist.sample(key) return latent - pretrained_model_name_or_path = "Wan-AI/Wan2.1-T2V-14B-Diffusers" key = jax.random.key(0) rngs = nnx.Rngs(key) - wan_vae = AutoencoderKLWan.from_config(pretrained_model_name_or_path, subfolder="vae", rngs=rngs) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + wan_vae = AutoencoderKLWan.from_config( + config.pretrained_model_name_or_path, + subfolder="vae", + rngs=rngs, + mesh=mesh + ) vae_cache = AutoencoderKLWanCache(wan_vae) video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" video = load_video(video_path) @@ -435,7 +514,7 @@ def vae_encode(video, wan_vae, vae_cache, key): graphdef, state = nnx.split(wan_vae) params = state.to_pure_dict() # This replaces random params with the model. - params = load_wan_vae(pretrained_model_name_or_path, params, "cpu") + params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) wan_vae = nnx.merge(graphdef, params) From b7c8ba679f5e1f27875e63eefd6399f0922e0014 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 29 May 2025 17:15:35 +0000 Subject: [PATCH 39/54] wan pipeline with generation. Correctness is still not verified. --- src/maxdiffusion/generate_wan.py | 19 ++- .../pipelines/wan/wan_pipeline.py | 158 +++++++++--------- 2 files changed, 97 insertions(+), 80 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 3ee5d6c3f..3ba04e0ca 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -13,14 +13,17 @@ # limitations under the License. from typing import Sequence +import time from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline from maxdiffusion import pyconfig from absl import app +from maxdiffusion.utils import export_to_video def run(config): pipeline = WanPipeline.from_pretrained(config) - pipeline( + s0 = time.perf_counter() + video = pipeline( prompt=config.prompt, negative_prompt=config.negative_prompt, height=config.height, @@ -29,6 +32,20 @@ def run(config): num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, ) + print("compile time: ", (time.perf_counter() - s0)) + s0 = time.perf_counter() + video = pipeline( + prompt=config.prompt, + negative_prompt=config.negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + ) + print("generation time: ", (time.perf_counter() - s0)) + export_to_video(video[0], "jax_output.mp4", fps=16) + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 1f97ef478..c57f8f90f 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -26,7 +26,6 @@ from ...models.wan.transformers.transformer_wan import WanModel from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache from maxdiffusion.video_processor import VideoProcessor -from ...utils import export_to_video from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState from transformers import AutoTokenizer, UMT5EncoderModel import ftfy @@ -314,75 +313,77 @@ def __call__( max_sequence_length: int = 512, latents: jax.Array = None, prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False ): - if num_frames % self.vae_scale_factor_temporal != 1: - max_logging.log( - f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + if not vae_only: + if num_frames % self.vae_scale_factor_temporal != 1: + max_logging.log( + f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + max_sequence_length=max_sequence_length, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds ) - num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 - num_frames = max(num_frames, 1) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - max_sequence_length=max_sequence_length, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds - ) - num_channel_latents = self.transformer.config.in_channels - if latents is None: - latents = self.prepare_latents( - batch_size=batch_size, - vae_scale_factor_temporal=self.vae_scale_factor_temporal, - vae_scale_factor_spatial=self.vae_scale_factor_spatial, - height=height, - width=width, - num_frames=num_frames, - num_channels_latents=num_channel_latents + num_channel_latents = self.transformer.config.in_channels + if latents is None: + latents = self.prepare_latents( + batch_size=batch_size, + vae_scale_factor_temporal=self.vae_scale_factor_temporal, + vae_scale_factor_spatial=self.vae_scale_factor_spatial, + height=height, + width=width, + num_frames=num_frames, + num_channels_latents=num_channel_latents + ) + + prompt_embeds = jnp.concatenate([prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype) + negative_prompt_embeds = jnp.concatenate([negative_prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype) + + latents = jax.device_put(latents, PositionalSharding(self.devices_array).replicate()) + prompt_embeds = jax.device_put(prompt_embeds, PositionalSharding(self.devices_array).replicate()) + negative_prompt_embeds = jax.device_put(negative_prompt_embeds, PositionalSharding(self.devices_array).replicate()) + + scheduler_state = self.scheduler.set_timesteps( + self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape ) - prompt_embeds = jnp.concatenate([prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype) - negative_prompt_embeds = jnp.concatenate([negative_prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype) - - latents = jax.device_put(latents, PositionalSharding(self.devices_array).replicate()) - prompt_embeds = jax.device_put(prompt_embeds, PositionalSharding(self.devices_array).replicate()) - negative_prompt_embeds = jax.device_put(negative_prompt_embeds, PositionalSharding(self.devices_array).replicate()) - - scheduler_state = self.scheduler.set_timesteps( - self.scheduler_state, num_inference_steps=self.config.num_inference_steps, shape=latents.shape - ) - - graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) + graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) - p_run_inference = partial( - run_inference, - guidance_scale=self.config.guidance_scale, - num_inference_steps=self.config.num_inference_steps, - scheduler=self.scheduler, - scheduler_state=scheduler_state - ) - with self.mesh: - latents = p_run_inference( - graphdef=graphdef, - sharded_state=state, - rest_of_state=rest_of_state, - latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds + p_run_inference = partial( + run_inference, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state ) - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim) - latents = latents / latents_std + latents_mean - - latents = latents.astype(self.config.weights_dtype) + with self.mesh: + latents = p_run_inference( + graphdef=graphdef, + sharded_state=state, + rest_of_state=rest_of_state, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds + ) + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim) + latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim) + latents = latents / latents_std + latents_mean + + latents = latents.astype(self.config.weights_dtype) jitted_decode = jax.jit( partial( @@ -396,9 +397,18 @@ def __call__( video = jnp.transpose(video, (0, 4, 1, 2, 3)) video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) video = self.video_processor.postprocess_video(video, output_type="np") - export_to_video(video[0], "jax_output.mp4", fps=24) + return video + + +@jax.jit +def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds): + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + return wan_transformer( + hidden_states=latents, + timestep=timestep, + encoder_hidden_states=prompt_embeds + )[0] - #@partial(jax.jit, static_argnums=(6, 7, 8)) def run_inference( graphdef, @@ -411,26 +421,16 @@ def run_inference( num_inference_steps: int, scheduler : FlaxUniPCMultistepScheduler, scheduler_state): - wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) do_classifier_free_guidance = guidance_scale > 1.0 for step in range(num_inference_steps): t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] timestep = jnp.broadcast_to(t, latents.shape[0]) - - noise_pred = wan_transformer( - hidden_states=latents, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - return_dict=False - )[0] + + noise_pred = transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds) if do_classifier_free_guidance: - noise_uncond = wan_transformer( - hidden_states=latents, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - return_dict=False - )[0] + noise_uncond = transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, negative_prompt_embeds) noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents \ No newline at end of file From 238890874ebc49eb6b232059095beec49ced24c0 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 30 May 2025 18:20:14 +0000 Subject: [PATCH 40/54] use collapse instead of reshape for final activation. --- .../wan/transformers/transformer_wan.py | 11 +++-- .../pipelines/wan/wan_pipeline.py | 49 +++++++++++-------- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 8727242d5..271192c44 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -447,13 +447,14 @@ def __call__( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[jax.Array, Dict[str, jax.Array]]: - batch_size, num_frames, height, width, num_channels = hidden_states.shape + batch_size, _, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size post_patch_num_frames = num_frames // p_t post_patch_height = height // p_h post_patch_width = width // p_w + hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) rotary_emb = self.rope(hidden_states) hidden_states = self.patch_embedding(hidden_states) hidden_states = jax.lax.collapse(hidden_states, 1, -1) @@ -472,9 +473,9 @@ def __call__( hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) hidden_states = self.proj_out(hidden_states) - # TODO - can this reshape happen in a single command? hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1) - hidden_states = hidden_states.reshape(batch_size, num_frames, height, width, num_channels) - # jax.debug.print("FINAL hidden_states min: {x}", x=hidden_states.min()) - # jax.debug.print("FINAL hidden_states max: {x}", x=hidden_states.max()) + hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6)) + hidden_states = jax.lax.collapse(hidden_states, 6, None) + hidden_states = jax.lax.collapse(hidden_states, 4, 6) + hidden_states = jax.lax.collapse(hidden_states, 2, 4) return hidden_states \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index c57f8f90f..51425c0e5 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -109,6 +109,16 @@ def __init__( self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.jitted_decode = jax.jit( + partial( + self.vae.decode, + feat_cache=self.vae_cache, + return_dict=False + ) + ) + + self.p_run_inference = None + @classmethod def load_text_encoder(cls, config: HyperParameters): text_encoder = UMT5EncoderModel.from_pretrained( @@ -184,20 +194,27 @@ def load_scheduler(cls, config): return scheduler, scheduler_state @classmethod - def from_pretrained(cls, config: HyperParameters): + def from_pretrained(cls, config: HyperParameters, vae_only=False): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) + transformer=None + tokenizer=None + scheduler=None + scheduler_state=None + text_encoder=None + if not vae_only: + with mesh: + transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + + text_encoder = cls.load_text_encoder(config=config) + tokenizer = cls.load_tokenizer(config=config) + + scheduler, scheduler_state = cls.load_scheduler(config=config) with mesh: wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) return WanPipeline( tokenizer=tokenizer, @@ -291,10 +308,10 @@ def prepare_latents( num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 shape = ( batch_size, + num_channels_latents, num_latent_frames, int(height) // vae_scale_factor_spatial, int(width) // vae_scale_factor_spatial, - num_channels_latents ) latents = jax.random.normal(rng, shape=shape, dtype=self.config.weights_dtype) @@ -370,6 +387,7 @@ def __call__( scheduler=self.scheduler, scheduler_state=scheduler_state ) + with self.mesh: latents = p_run_inference( graphdef=graphdef, @@ -379,21 +397,13 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds ) - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim) + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) + latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) latents = latents / latents_std + latents_mean - latents = latents.astype(self.config.weights_dtype) - jitted_decode = jax.jit( - partial( - self.vae.decode, - feat_cache=self.vae_cache, - return_dict=False - ) - ) with self.mesh: - video = jitted_decode(latents)[0] + video = self.jitted_decode(latents)[0] video = jnp.transpose(video, (0, 4, 1, 2, 3)) video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) video = self.video_processor.postprocess_video(video, output_type="np") @@ -431,6 +441,5 @@ def run_inference( if do_classifier_free_guidance: noise_uncond = transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, negative_prompt_embeds) noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents \ No newline at end of file From 5cc2e495ba5035062c1a159bcf5e7b46684a8804 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 30 May 2025 23:00:16 +0000 Subject: [PATCH 41/54] implements a working wan 2.1 pipeline. --- src/maxdiffusion/generate_wan.py | 24 +++++++++++-------- .../pipelines/wan/wan_pipeline.py | 12 ++-------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 3ba04e0ca..ed4a02c7d 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Sequence +import jax import time from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline from maxdiffusion import pyconfig @@ -21,7 +22,6 @@ def run(config): pipeline = WanPipeline.from_pretrained(config) - s0 = time.perf_counter() video = pipeline( prompt=config.prompt, @@ -32,17 +32,20 @@ def run(config): num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, ) + print("compile time: ", (time.perf_counter() - s0)) + export_to_video(video[0], "jax_output.mp4", fps=16) s0 = time.perf_counter() - video = pipeline( - prompt=config.prompt, - negative_prompt=config.negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, - ) + with jax.profiler.trace("/tmp/trace/"): + video = pipeline( + prompt=config.prompt, + negative_prompt=config.negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + ) print("generation time: ", (time.perf_counter() - s0)) export_to_video(video[0], "jax_output.mp4", fps=16) @@ -51,5 +54,6 @@ def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) run(pyconfig.config) + if __name__ == "__main__": app.run(main) \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 51425c0e5..b4f4f9dc1 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -109,14 +109,6 @@ def __init__( self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - self.jitted_decode = jax.jit( - partial( - self.vae.decode, - feat_cache=self.vae_cache, - return_dict=False - ) - ) - self.p_run_inference = None @classmethod @@ -402,8 +394,8 @@ def __call__( latents = latents / latents_std + latents_mean latents = latents.astype(self.config.weights_dtype) - with self.mesh: - video = self.jitted_decode(latents)[0] + video = self.vae.decode(latents, self.vae_cache)[0] + video = jnp.transpose(video, (0, 4, 1, 2, 3)) video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) video = self.video_processor.postprocess_video(video, output_type="np") From 5f2434da60ed0083ce3b33eaf7a71e8918b9e711 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 3 Jun 2025 00:16:50 +0000 Subject: [PATCH 42/54] fix attention bug for lower frames. --- src/maxdiffusion/models/attention_flax.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index f87627a87..be91a94ca 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -75,11 +75,16 @@ def _reshape_batch_dim_to_heads(tensor, heads): return tensor def _reshape_heads_to_batch_dim(tensor, heads): - batch_size, seq_len, dim = tensor.shape - head_size = heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + head_size = heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + else: + batch_size, head_size, seq_len, head_dim = tensor.shape + tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim) + return tensor def _reshape_heads_to_head_dim(tensor): From d64e5219dc314070fd95bd3ca7a36a1527892a8d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 4 Jun 2025 00:29:45 +0000 Subject: [PATCH 43/54] reduces memory significantly when loading transformer. Needs clean up. --- .../wan/transformers/transformer_wan.py | 73 +++++++++------- .../pipelines/wan/wan_pipeline.py | 85 ++++++++++++++----- 2 files changed, 108 insertions(+), 50 deletions(-) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 271192c44..352601843 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -34,6 +34,38 @@ BlockSizes = common_types.BlockSizes +def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int): + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + freqs = [] + for dim in [t_dim, h_dim, w_dim]: + freq = get_1d_rotary_pos_embed( + dim, + max_seq_len, + theta, + freqs_dtype=jnp.float64, + use_real=False + ) + freqs.append(freq) + freqs = jnp.concatenate(freqs, axis=1) + # sizes = jnp.array([ + # attention_head_dim // 2 - 2 * (attention_head_dim // 6), + # attention_head_dim // 6, + # attention_head_dim // 6, + # ]) + # cumulative_sizes = jnp.cumsum(jnp.array(sizes)) + # split_indices = cumulative_sizes[:-1] + t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6) + hw_size = attention_head_dim // 6 + + dims = [t_size, hw_size, hw_size] + + # Calculate split indices as a static list of integers + cumulative_sizes = np.cumsum(dims) + split_indices = cumulative_sizes[:-1].tolist() + freqs_split = jnp.split(freqs, split_indices, axis=1) + return freqs_split + class WanRotaryPosEmbed(nnx.Module): def __init__( self, @@ -45,44 +77,23 @@ def __init__( self.attention_head_dim = attention_head_dim self.patch_size = patch_size self.max_seq_len = max_seq_len - - h_dim = w_dim = 2 * (attention_head_dim // 6) - t_dim = attention_head_dim - h_dim - w_dim - - freqs = [] - for dim in [t_dim, h_dim, w_dim]: - freq = get_1d_rotary_pos_embed( - dim, - self.max_seq_len, - theta, - freqs_dtype=jnp.float64, - use_real=False - ) - freqs.append(freq) - freqs = jnp.concatenate(freqs, axis=1) - - sizes = [ - self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), - self.attention_head_dim // 6, - self.attention_head_dim // 6, - ] - cumulative_sizes = jnp.cumsum(jnp.array(sizes)) - split_indices = cumulative_sizes[:-1] - self.freqs_split = jnp.split(freqs, split_indices, axis=1) + self.theta = theta def __call__(self, hidden_states: jax.Array) -> jax.Array: _, num_frames, height, width, _ = hidden_states.shape p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - freqs_f = jnp.expand_dims(jnp.expand_dims(self.freqs_split[0][:ppf], axis=1), axis=1) - freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, self.freqs_split[0].shape[-1])) + freqs_split = get_frequencies(self.max_seq_len, self.theta, self.attention_head_dim) + + freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1) + freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1])) - freqs_h = jnp.expand_dims(jnp.expand_dims(self.freqs_split[1][:pph], axis=0), axis=2) - freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, self.freqs_split[1].shape[-1])) + freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2) + freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1])) - freqs_w = jnp.expand_dims(jnp.expand_dims(self.freqs_split[2][:ppw], axis=0), axis=1) - freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, self.freqs_split[2].shape[-1])) + freqs_w = jnp.expand_dims(jnp.expand_dims(freqs_split[2][:ppw], axis=0), axis=1) + freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1])) freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1) freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1)) @@ -362,7 +373,7 @@ def __init__( qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, image_dim: Optional[int] = None, - added_kn_proj_dim: Optional[int] = None, + added_kv_proj_dim: Optional[int] = None, rope_max_seq_len: int = 1024, pos_embed_seq_len: Optional[int] = None, flash_min_seq_length: int = 4096, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index b4f4f9dc1..f8272b866 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -18,6 +18,8 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh, PositionalSharding +import flax +import flax.linen as nn from flax import nnx from ...pyconfig import HyperParameters from ... import max_logging @@ -54,6 +56,48 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl vs.sharding_rules = logical_axis_rules return vs + +partial(nnx.jit, static_argnums=(3,)) +def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): + # breakpoint() + def create_model(rngs: nnx.Rngs, wan_config: dict): + wan_transformer = WanModel(**wan_config, rngs=rngs) + return wan_transformer + + wan_config = WanModel.load_config( + config.pretrained_model_name_or_path, + subfolder="transformer" + ) + wan_config["mesh"] = mesh + wan_config["dtype"] = config.activations_dtype + wan_config["weights_dtype"] = config.weights_dtype + wan_config["attention"] = config.attention + p_model_factory = partial(create_model, wan_config=wan_config) + wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs) + graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) + #breakpoint() + logical_state_spec = nnx.get_partition_spec(state) + logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) + logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) + params = state.to_pure_dict() + state = dict(nnx.to_flat_state(state)) + # del state + params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu") + params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) + for path, val in flax.traverse_util.flatten_dict(params).items(): + sharding = logical_state_sharding[path].value + state[path].value = jax.device_put(val, sharding) + state = nnx.from_flat_state(state) + p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=config.logical_axis_rules) + state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) + pspecs = nnx.get_partition_spec(state) + #breakpoint() + sharded_state = jax.lax.with_sharding_constraint(state, pspecs) + #breakpoint() + #wan_transformer = jax.jit(nnx.merge(graphdef, sharded_state, rest_of_state), in_shardings=None, out_shardings=sharded_state) + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + return wan_transformer + partial(nnx.jit, static_argnums=(1,)) def create_sharded_logical_model(model, logical_axis_rules): graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...) @@ -154,26 +198,29 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H @classmethod def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): - wan_transformer = WanModel.from_config( - config.pretrained_model_name_or_path, - subfolder="transformer", - rngs=rngs, - attention=config.attention, - mesh=mesh, - dtype=config.activations_dtype, - weights_dtype=config.weights_dtype - ) - graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) - params = state.to_pure_dict() - del state - params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu") - params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) - params = jax.device_put(params, PositionalSharding(devices_array).replicate()) - wan_transformer = nnx.merge(graphdef, params, rest_of_state) - # Shard - p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules) with mesh: - wan_transformer = p_create_sharded_logical_model(model=wan_transformer) + wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + # wan_transformer = WanModel.from_config( + # config.pretrained_model_name_or_path, + # subfolder="transformer", + # rngs=rngs, + # attention=config.attention, + # mesh=mesh, + # dtype=config.activations_dtype, + # weights_dtype=config.weights_dtype + # ) + # graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) + # breakpoint() + # params = state.to_pure_dict() + # del state + # #params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu") + # params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) + # #params = jax.device_put(params, PositionalSharding(devices_array).replicate()) + # wan_transformer = nnx.merge(graphdef, params, rest_of_state) + # # Shard + # p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules) + # with mesh: + # wan_transformer = p_create_sharded_logical_model(model=wan_transformer) return wan_transformer @classmethod From 56f5225768e5c66b387b79655673617a0219b3e1 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 4 Jun 2025 19:54:35 +0000 Subject: [PATCH 44/54] support bs > 1. Issue where all gens except for 1st coming out bad. --- src/maxdiffusion/configs/base_wan_14b.yml | 16 ++-- src/maxdiffusion/generate_wan.py | 10 ++- .../models/wan/autoencoder_kl_wan.py | 15 ++-- .../wan/transformers/transformer_wan.py | 2 +- .../pipelines/wan/wan_pipeline.py | 86 +++++++++---------- 5 files changed, 68 insertions(+), 61 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 1f5920a7b..1aee06894 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -53,14 +53,14 @@ split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te flash_block_sizes: { - "block_q" : 512, - "block_kv_compute" : 512, - "block_kv" : 512, - "block_q_dkv" : 512, - "block_kv_dkv" : 512, - "block_kv_dkv_compute" : 512, - "block_q_dq" : 512, - "block_kv_dq" : 512 + "block_q" : 1024, + "block_kv_compute" : 1024, + "block_kv" : 1024, + "block_q_dkv" : 1024, + "block_kv_dkv" : 1024, + "block_kv_dkv_compute" : 1024, + "block_q_dq" : 1024, + "block_kv_dq" : 1024 } # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index ed4a02c7d..0f2cefa79 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -23,7 +23,7 @@ def run(config): pipeline = WanPipeline.from_pretrained(config) s0 = time.perf_counter() - video = pipeline( + videos = pipeline( prompt=config.prompt, negative_prompt=config.negative_prompt, height=config.height, @@ -34,10 +34,11 @@ def run(config): ) print("compile time: ", (time.perf_counter() - s0)) - export_to_video(video[0], "jax_output.mp4", fps=16) + for i in range(len(videos)): + export_to_video(videos[i], f"wan_output_{i}.mp4", fps=16) s0 = time.perf_counter() with jax.profiler.trace("/tmp/trace/"): - video = pipeline( + videos = pipeline( prompt=config.prompt, negative_prompt=config.negative_prompt, height=config.height, @@ -47,7 +48,8 @@ def run(config): guidance_scale=config.guidance_scale, ) print("generation time: ", (time.perf_counter() - s0)) - export_to_video(video[0], "jax_output.mp4", fps=16) + for i in range(len(videos)): + export_to_video(videos[i], f"wan_output_{i}.mp4", fps=16) def main(argv: Sequence[str]) -> None: diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index c80f88b01..cf280fdff 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -1131,12 +1131,17 @@ def _decode( # Ideally shouldn't need to do this however, can't find where the frame is going out of sync. # Most likely due to an incorrect reshaping in the decoder. fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :] - if len(fm1.shape) == 4: - fm1 = jnp.expand_dims(fm1, axis=0) - fm2 = jnp.expand_dims(fm2, axis=0) - fm3 = jnp.expand_dims(fm3, axis=0) - fm4 = jnp.expand_dims(fm4, axis=0) + # When batch_size is 0, expand batch dim for contatenation + # else, expand frame dim for concatenation so that batch dim stays intact. + axis=0 + if fm1.shape[0] > 1: + axis=1 + if len(fm1.shape) == 4: + fm1 = jnp.expand_dims(fm1, axis=axis) + fm2 = jnp.expand_dims(fm2, axis=axis) + fm3 = jnp.expand_dims(fm3, axis=axis) + fm4 = jnp.expand_dims(fm4, axis=axis) out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1) out = jnp.clip(out, min=-1.0, max=1.0) feat_cache.clear_cache() diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 352601843..a3cfd7c60 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -398,7 +398,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("batch",)), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, None, "conv_out",)), ) # 2. Condition embeddings diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index f8272b866..3bf98ddc4 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -17,13 +17,14 @@ import numpy as np import jax import jax.numpy as jnp -from jax.sharding import Mesh, PositionalSharding +from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P import flax import flax.linen as nn from flax import nnx from ...pyconfig import HyperParameters from ... import max_logging from ... import max_utils +from ...max_utils import get_flash_block_sizes, get_precision from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae from ...models.wan.transformers.transformer_wan import WanModel from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache @@ -59,11 +60,12 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl partial(nnx.jit, static_argnums=(3,)) def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): - # breakpoint() + def create_model(rngs: nnx.Rngs, wan_config: dict): wan_transformer = WanModel(**wan_config, rngs=rngs) return wan_transformer + # 1. Load config. wan_config = WanModel.load_config( config.pretrained_model_name_or_path, subfolder="transformer" @@ -72,32 +74,39 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): wan_config["dtype"] = config.activations_dtype wan_config["weights_dtype"] = config.weights_dtype wan_config["attention"] = config.attention + wan_config["precision"] = get_precision(config) + wan_config["flash_block_sizes"] = get_flash_block_sizes(config) + + # 2. eval_shape - will not use flops or create weights on device + # thus not using HBM memory. p_model_factory = partial(create_model, wan_config=wan_config) wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs) graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) - #breakpoint() + + # 3. retrieve the state shardings, mapping logical names to mesh axis names. logical_state_spec = nnx.get_partition_spec(state) logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) params = state.to_pure_dict() state = dict(nnx.to_flat_state(state)) - # del state + + # 4. Load pretrained weights and move them to device using the state shardings from (3) above. + # This helps with loading sharded weights directly into the accelerators without fist copying them + # all to one device and then distributing them, thus using low HBM memory. params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu") params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) for path, val in flax.traverse_util.flatten_dict(params).items(): sharding = logical_state_sharding[path].value - state[path].value = jax.device_put(val, sharding) + try: + state[path].value = jax.device_put(val, sharding) + except: + breakpoint() state = nnx.from_flat_state(state) - p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=config.logical_axis_rules) - state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) - pspecs = nnx.get_partition_spec(state) - #breakpoint() - sharded_state = jax.lax.with_sharding_constraint(state, pspecs) - #breakpoint() - #wan_transformer = jax.jit(nnx.merge(graphdef, sharded_state, rest_of_state), in_shardings=None, out_shardings=sharded_state) - wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + + wan_transformer = nnx.merge(graphdef, state, rest_of_state) return wan_transformer + partial(nnx.jit, static_argnums=(1,)) def create_sharded_logical_model(model, logical_axis_rules): graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...) @@ -108,6 +117,7 @@ def create_sharded_logical_model(model, logical_axis_rules): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) return wan_transformer + class WanPipeline: r""" Pipeline for text-to-video generation using Wan. @@ -155,6 +165,7 @@ def __init__( self.p_run_inference = None + @classmethod def load_text_encoder(cls, config: HyperParameters): text_encoder = UMT5EncoderModel.from_pretrained( @@ -163,6 +174,7 @@ def load_text_encoder(cls, config: HyperParameters): ) return text_encoder + @classmethod def load_tokenizer(cls, config: HyperParameters): tokenizer = AutoTokenizer.from_pretrained( @@ -171,6 +183,7 @@ def load_tokenizer(cls, config: HyperParameters): ) return tokenizer + @classmethod def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): wan_vae = AutoencoderKLWan.from_config( @@ -196,33 +209,14 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H wan_vae = p_create_sharded_logical_model(model=wan_vae) return wan_vae, vae_cache + @classmethod def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): with mesh: wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - # wan_transformer = WanModel.from_config( - # config.pretrained_model_name_or_path, - # subfolder="transformer", - # rngs=rngs, - # attention=config.attention, - # mesh=mesh, - # dtype=config.activations_dtype, - # weights_dtype=config.weights_dtype - # ) - # graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) - # breakpoint() - # params = state.to_pure_dict() - # del state - # #params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu") - # params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) - # #params = jax.device_put(params, PositionalSharding(devices_array).replicate()) - # wan_transformer = nnx.merge(graphdef, params, rest_of_state) - # # Shard - # p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules) - # with mesh: - # wan_transformer = p_create_sharded_logical_model(model=wan_transformer) return wan_transformer + @classmethod def load_scheduler(cls, config): scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( @@ -232,6 +226,7 @@ def load_scheduler(cls, config): ) return scheduler, scheduler_state + @classmethod def from_pretrained(cls, config: HyperParameters, vae_only=False): devices_array = max_utils.create_device_mesh(config) @@ -268,6 +263,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False): config=config ) + def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -302,6 +298,7 @@ def _get_t5_prompt_embeds( return prompt_embeds + def encode_prompt( self, prompt: Union[str, List[str]], @@ -333,6 +330,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + def prepare_latents( self, batch_size: int, @@ -356,6 +354,7 @@ def prepare_latents( return latents + def __call__( self, prompt: Union[str, List[str]] = None, @@ -382,9 +381,9 @@ def __call__( # 2. Define call parameters if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) + prompt = [prompt] + + batch_size = len(prompt) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, @@ -406,12 +405,13 @@ def __call__( num_channels_latents=num_channel_latents ) - prompt_embeds = jnp.concatenate([prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype) - negative_prompt_embeds = jnp.concatenate([negative_prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype) - - latents = jax.device_put(latents, PositionalSharding(self.devices_array).replicate()) - prompt_embeds = jax.device_put(prompt_embeds, PositionalSharding(self.devices_array).replicate()) - negative_prompt_embeds = jax.device_put(negative_prompt_embeds, PositionalSharding(self.devices_array).replicate()) + data_sharding = PositionalSharding(self.devices_array).replicate() + if len(prompt) % jax.device_count() == 0: + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + + latents = jax.device_put(latents, data_sharding) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) scheduler_state = self.scheduler.set_timesteps( self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape From 9ee7fd30226a20031214c17f8631ef25ab623fa8 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 6 Jun 2025 00:12:26 +0000 Subject: [PATCH 45/54] improves performance by 14% on v5p. --- src/maxdiffusion/configs/base_wan_14b.yml | 12 +----- src/maxdiffusion/generate_wan.py | 4 +- src/maxdiffusion/models/attention_flax.py | 37 ++++++++++++------- .../wan/transformers/transformer_wan.py | 13 +++---- 4 files changed, 34 insertions(+), 32 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 1aee06894..6968dd5df 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -52,16 +52,7 @@ from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te -flash_block_sizes: { - "block_q" : 1024, - "block_kv_compute" : 1024, - "block_kv" : 1024, - "block_q_dkv" : 1024, - "block_kv_dkv" : 1024, - "block_kv_dkv_compute" : 1024, - "block_q_dq" : 1024, - "block_kv_dq" : 1024 -} +flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 @@ -127,6 +118,7 @@ logical_axis_rules: [ ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], + ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 0f2cefa79..01975937b 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -35,7 +35,7 @@ def run(config): print("compile time: ", (time.perf_counter() - s0)) for i in range(len(videos)): - export_to_video(videos[i], f"wan_output_{i}.mp4", fps=16) + export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=16) s0 = time.perf_counter() with jax.profiler.trace("/tmp/trace/"): videos = pipeline( @@ -49,7 +49,7 @@ def run(config): ) print("generation time: ", (time.perf_counter() - s0)) for i in range(len(videos)): - export_to_video(videos[i], f"wan_output_{i}.mp4", fps=16) + export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=16) def main(argv: Sequence[str]) -> None: diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index be91a94ca..9028c47b1 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -19,6 +19,7 @@ import flax.linen as nn from flax import nnx import jax +from jax.sharding import PartitionSpec import jax.numpy as jnp from jax.experimental import shard_map from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask @@ -139,21 +140,23 @@ def _tpu_flash_attention( heads: int, mesh: Mesh, flash_axis_names: AxisNames, - flash_block_sizes: BlockSizes) -> jax.Array: + flash_block_sizes: BlockSizes, + dtype: jnp.dtype = jnp.float32) -> jax.Array: """TPU Flash Attention""" + max_block_size = 1024 if dtype == jnp.bfloat16 else 512 if flash_block_sizes: block_sizes = flash_block_sizes else: block_sizes = splash_attention_kernel.BlockSizes( - block_q=min(512, query.shape[2]), - block_kv_compute=min(512, key.shape[2]), - block_kv=min(512, key.shape[2]), - block_q_dkv=min(512, query.shape[2]), - block_kv_dkv=min(512, key.shape[2]), - block_kv_dkv_compute=min(512, query.shape[2]), - block_q_dq=min(512, query.shape[2]), - block_kv_dq=min(512, query.shape[2]), + block_q=min(max_block_size, query.shape[2]), + block_kv_compute=min(max_block_size, key.shape[2]), + block_kv=min(max_block_size, key.shape[2]), + block_q_dkv=min(max_block_size, query.shape[2]), + block_kv_dkv=min(max_block_size, key.shape[2]), + block_kv_dkv_compute=min(max_block_size, query.shape[2]), + block_q_dq=min(max_block_size, query.shape[2]), + block_kv_dq=min(max_block_size, query.shape[2]), ) query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q) @@ -340,7 +343,7 @@ def _apply_attention( if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention: return _apply_attention_dot(query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention) elif attention_kernel == "flash": - return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes) + return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype) elif attention_kernel == "cudnn_flash_te": return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) else: @@ -668,7 +671,7 @@ def __init__( rngs=rngs, epsilon=eps, dtype=dtype, - scale_init=nnx.with_partitioning(nnx.initializers.ones, ("heads", )), + scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm", )), param_dtype=weights_dtype ) @@ -676,7 +679,7 @@ def __init__( num_features=self.inner_dim, rngs=rngs, dtype=dtype, - scale_init=nnx.with_partitioning(nnx.initializers.ones, ("heads", )), + scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm", )), param_dtype=weights_dtype ) @@ -702,9 +705,12 @@ def __call__( encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None ) -> jax.Array: + hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec('data', 'fsdp','tensor')) + encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec('data', 'fsdp','tensor')) dtype = hidden_states.dtype if encoder_hidden_states is None: encoder_hidden_states = hidden_states + query_proj = self.query(hidden_states) key_proj = self.key(encoder_hidden_states) value_proj = self.value(encoder_hidden_states) @@ -717,8 +723,13 @@ def __call__( key_proj = _unflatten_heads(key_proj, self.heads) value_proj = _unflatten_heads(value_proj, self.heads) query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) - + query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec('data', 'tensor', None, None)) + key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec('data', 'tensor', None, None)) + value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec('data', 'tensor', None, None)) + attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) + attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec('data', None, None)) + attn_output = attn_output.astype(dtype=dtype) hidden_states = self.proj_attn(attn_output) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index a3cfd7c60..566290299 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -179,8 +179,6 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), ) def __call__(self, x: jax.Array) -> jax.Array: @@ -231,7 +229,6 @@ def __init__( param_dtype=weights_dtype, precision=precision, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed",)), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), ) def __call__(self, hidden_states: jax.Array) -> jax.Array: @@ -338,7 +335,7 @@ def __call__( # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) - attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) + attn_output = self.attn1(hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb) hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) # 2. Cross-attention @@ -443,11 +440,13 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)), ) key = rngs.params() - self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5) + self.scale_shift_table = nnx.Param( + jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")) + ) def __call__( self, From b84fc343d5151fd0c67279ca1c09345821a9b6fe Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 6 Jun 2025 05:28:17 +0000 Subject: [PATCH 46/54] implements skip layer guidance for better generations. --- src/maxdiffusion/configs/base_wan_14b.yml | 11 +++- src/maxdiffusion/generate_wan.py | 13 +++++ .../wan/transformers/transformer_wan.py | 14 ++++- .../pipelines/wan/wan_pipeline.py | 57 ++++++++++++++++--- 4 files changed, 82 insertions(+), 13 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 6968dd5df..f5ac459ad 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -207,15 +207,20 @@ prompt: "A cat and a dog baking a cake together in a kitchen. The cat is careful prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True -height: 720 -width: 1280 +height: 480 +width: 832 num_frames: 81 guidance_scale: 5.0 +flow_shift: 3.0 + +# skip layer guidance +slg_layers: [9] +slg_start: 0.2 +slg_end: 1.0 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 save_final_checkpoint: False -flow_shift: 5.0 # SDXL Lightning parameters lightning_from_pt: True diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 01975937b..9b2838132 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -21,8 +21,15 @@ from maxdiffusion.utils import export_to_video def run(config): + print("seed: ", config.seed) pipeline = WanPipeline.from_pretrained(config) s0 = time.perf_counter() + + # Skip layer guidance + slg_layers = config.slg_layers + slg_start = config.slg_start + slg_end = config.slg_end + videos = pipeline( prompt=config.prompt, negative_prompt=config.negative_prompt, @@ -31,6 +38,9 @@ def run(config): num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, + slg_layers=slg_layers, + slg_start=slg_start, + slg_end=slg_end ) print("compile time: ", (time.perf_counter() - s0)) @@ -46,6 +56,9 @@ def run(config): num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, + slg_layers=slg_layers, + slg_start=slg_start, + slg_end=slg_end ) print("generation time: ", (time.perf_counter() - s0)) for i in range(len(videos)): diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 566290299..a0312b5bf 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -14,7 +14,7 @@ limitations under the License. """ -from typing import Tuple, Optional, Dict, Union, Any +from typing import Tuple, Optional, Dict, Union, Any, List import math import jax import jax.numpy as jnp @@ -453,6 +453,8 @@ def __call__( hidden_states: jax.Array, timestep: jax.Array, encoder_hidden_states: jax.Array, + is_uncond: jax.Array, # jnp.bool_ scalar + slg_mask: jax.Array, # jnp.bool_ array of shape (num_blocks,) encoder_hidden_states_image: Optional[jax.Array] = None, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -476,8 +478,14 @@ def __call__( if encoder_hidden_states_image is not None: raise NotImplementedError("img2vid is not yet implemented.") - for block in self.blocks: - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + for block_idx, block in enumerate(self.blocks): + should_skip_block = slg_mask[block_idx] & is_uncond + hidden_states = jax.lax.cond( + should_skip_block, + lambda hs: hs, # If true, pass through original hidden_states (skip block) + lambda _: block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb), + hidden_states + ) shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 3bf98ddc4..c5db151ce 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -369,7 +369,10 @@ def __call__( latents: jax.Array = None, prompt_embeds: jax.Array = None, negative_prompt_embeds: jax.Array = None, - vae_only: bool = False + vae_only: bool = False, + slg_layers: List[int] = None, + slg_start: float = 0.0, + slg_end: float = 1.0 ): if not vae_only: if num_frames % self.vae_scale_factor_temporal != 1: @@ -424,7 +427,11 @@ def __call__( guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, scheduler=self.scheduler, - scheduler_state=scheduler_state + scheduler_state=scheduler_state, + slg_layers=slg_layers, + slg_start=slg_start, + slg_end=slg_end, + num_transformer_layers=self.transformer.config.num_layers ) with self.mesh: @@ -450,12 +457,22 @@ def __call__( @jax.jit -def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds): +def transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_embeds, + is_uncond, + slg_mask): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) return wan_transformer( hidden_states=latents, timestep=timestep, - encoder_hidden_states=prompt_embeds + encoder_hidden_states=prompt_embeds, + is_uncond=is_uncond, + slg_mask=slg_mask )[0] #@partial(jax.jit, static_argnums=(6, 7, 8)) @@ -469,16 +486,42 @@ def run_inference( guidance_scale: float, num_inference_steps: int, scheduler : FlaxUniPCMultistepScheduler, - scheduler_state): + num_transformer_layers: int, + scheduler_state, + slg_layers: List[int] = None, + slg_start: float = 0.0, + slg_end: float = 1.0 + ): do_classifier_free_guidance = guidance_scale > 1.0 for step in range(num_inference_steps): + slg_mask = jnp.zeros(num_transformer_layers, dtype=jnp.bool_) + if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps): + slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True) t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] timestep = jnp.broadcast_to(t, latents.shape[0]) - noise_pred = transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds) + noise_pred = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_embeds, + is_uncond=jnp.array(False, dtype=jnp.bool_), + slg_mask=slg_mask + ) if do_classifier_free_guidance: - noise_uncond = transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, negative_prompt_embeds) + noise_uncond = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + negative_prompt_embeds, + is_uncond=jnp.array(True, dtype=jnp.bool_), + slg_mask=slg_mask + ) noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents \ No newline at end of file From 05f05544ff41a29ed70570e7ee63a11d6a681a74 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 10 Jun 2025 20:43:31 +0000 Subject: [PATCH 47/54] initial commit for wan training --- .../checkpointing/wan_checkpointer.py | 71 +++++++++ src/maxdiffusion/configs/base_wan_14b.yml | 1 + src/maxdiffusion/models/attention_flax.py | 1 + src/maxdiffusion/pipelines/wan/__init__.py | 17 ++ .../pipelines/wan/wan_pipeline.py | 2 +- src/maxdiffusion/train_wan.py | 37 +++++ src/maxdiffusion/trainers/wan_trainer.py | 148 ++++++++++++++++++ 7 files changed, 276 insertions(+), 1 deletion(-) create mode 100644 src/maxdiffusion/checkpointing/wan_checkpointer.py create mode 100644 src/maxdiffusion/pipelines/wan/__init__.py create mode 100644 src/maxdiffusion/train_wan.py create mode 100644 src/maxdiffusion/trainers/wan_trainer.py diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py new file mode 100644 index 000000000..fefcb5371 --- /dev/null +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -0,0 +1,71 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from abc import ABC +from flax import nnx +from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) +from ..pipelines.wan.wan_pipeline import WanPipeline +from .. import max_logging, max_utils + +WAN_CHECKPOINT = "WAN_CHECKPOINT" + +class WanCheckpointer(ABC): + def __init__(self, config, checkpoint_type): + self.config = config + self.checkpoint_type = checkpoint_type + + self.checkpoint_manager = create_orbax_checkpoint_manager( + self.config.checkpoint_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=checkpoint_type, + dataset_type=config.dataset_type + ) + + # @nnx.jit + def _create_optimizer(self, model, config, learning_rate): + learning_rate_scheduler = max_utils.create_learning_rate_schedule( + learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps + ) + tx = max_utils.create_optimizer(config, learning_rate_scheduler) + # tx = nnx.Optimizer(model, tx) + + # _, state, rest_of_state = nnx.split((model, tx), ...) + # nnx.update((model, tx), state, rest_of_state) + + + return nnx.Optimizer(model, tx), learning_rate_scheduler + + def load_wan_configs_from_orbax(self, step): + max_logging.log("Restoring stable diffusion configs") + if step is None: + step = self.checkpoint_manager.latest_step() + if step is None: + return None + + def load_diffusers_checkpoint(self): + pipeline = WanPipeline.from_pretrained(self.config) + return pipeline + + def load_checkpoint(self, step=None): + model_configs = self.load_wan_configs_from_orbax(step) + + if model_configs: + raise NotImplemented("model configs should not exist in orbax") + else: + pipeline = self.load_diffusers_checkpoint() + + return pipeline \ No newline at end of file diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index f5ac459ad..dda08817f 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -220,6 +220,7 @@ slg_end: 1.0 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 +fps: 24 save_final_checkpoint: False # SDXL Lightning parameters diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 9028c47b1..bce251dd0 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -524,6 +524,7 @@ class AttentionOp(nn.Module): quant: Quant = None def setup(self): + self.dpa_layer = None if self.attention_kernel == "cudnn_flash_te": from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error diff --git a/src/maxdiffusion/pipelines/wan/__init__.py b/src/maxdiffusion/pipelines/wan/__init__.py new file mode 100644 index 000000000..dab0ec292 --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/__init__.py @@ -0,0 +1,17 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from .wan_pipeline import WanPipeline \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index c5db151ce..afe9b8bc5 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -100,7 +100,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): try: state[path].value = jax.device_put(val, sharding) except: - breakpoint() + raise ValueError("value should exist.") state = nnx.from_flat_state(state) wan_transformer = nnx.merge(graphdef, state, rest_of_state) diff --git a/src/maxdiffusion/train_wan.py b/src/maxdiffusion/train_wan.py new file mode 100644 index 000000000..62d9ba859 --- /dev/null +++ b/src/maxdiffusion/train_wan.py @@ -0,0 +1,37 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +from typing import Sequence + +import jax +from absl import app +from maxdiffusion import max_logging, pyconfig +from maxdiffusion.train_utils import validate_train_config + +def train(config): + from maxdiffusion.trainers.wan_trainer import WanTrainer + trainer = WanTrainer(config) + trainer.start_training() + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + config = pyconfig.config + validate_train_config(config) + max_logging.log(f"Found {jax.device_count()} devices.") + train(config) + +if __name__ == "__main__": + app.run(main) \ No newline at end of file diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py new file mode 100644 index 000000000..1e54db989 --- /dev/null +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -0,0 +1,148 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import functools +import numpy as np +import jax.numpy as jnp +import jax +import jax.tree_util as jtu +from flax import nnx +from ..schedulers import FlaxEulerDiscreteScheduler +from .. import max_utils +from .. import max_logging +from ..checkpointing.wan_checkpointer import ( + WanCheckpointer, + WAN_CHECKPOINT +) +from multihost_dataloading import _form_global_array + +class WanTrainer(WanCheckpointer): + def __init__(self, config): + WanCheckpointer.__init__(self, config, WAN_CHECKPOINT) + if config.train_text_encoder: + raise ValueError("this script currently doesn't support training text_encoders") + + def post_training_steps(self, pipeline, params, train_states, msg=""): + pass + + def create_scheduler(self, pipeline, params): + # TODO - set right scheduler + noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( + pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, subfolder="scheduler", dtype=jnp.float32 + ) + noise_scheduler_state = noise_scheduler.set_timesteps( + state=noise_scheduler_state, num_inference_steps=self.config.num_inference_steps, timestep_spacing="flux" + ) + return noise_scheduler, noise_scheduler_state + + def calculate_tflops(self, pipeline): + pass + + def load_dataset(self, pipeline): + # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314 + # Image pre-training - txt2img 256px + # Image-video joint training - stage 1. 256 px images and 192px 5 sec videos at fps=16 + # Image-video joint training - stage 2. 480px images and 480px 5 sec videos at fps=16 + # Image-video joint training - stage final. 720px images and 720px 5 sec videos at fps=16 + # prompt embeds shape: (1, 512, 4096) + # For now, we will pass the same latents over and over + # TODO - create a dataset + global_batch_size = self.config.per_device_batch_size * jax.device_count() + prompt_embeds = jax.random.normal(jax.random.key(self.config.seed), (global_batch_size, 512, 4096)) + latents = pipeline.prepare_latents( + global_batch_size, + vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, + vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, + height=self.config.height, + width=self.config.width, + num_frames=self.config.num_frames, + num_channels_latents=pipeline.transformer.config.in_channels + ) + return (latents, prompt_embeds) + + def start_training(self): + + pipeline = self.load_checkpoint() + mesh = pipeline.mesh + + optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, self.config.learning_rate) + + # @nnx.jit + # def create_transformer_state(transformer): + # optimizer = self._create_optimizer(transformer, self.config, self.config.learning_rate) + # breakpoint() + # _, state = nnx.split((transformer, optimizer)) + + # with mesh: + # create_transformer_state(pipeline.transformer) + + #graphdef, state = nnx.plit((pipeline.transformer, optimizer)) + dummy_inputs = self.load_dataset(pipeline) + dummy_inputs = tuple([jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs]) + + self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs) + + def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): + + graphdef, state = nnx.split((pipeline.transformer, optimizer)) + state = state.to_pure_dict() + p_train_step = jax.jit( + train_step, + donate_argnums=(1,), + ) + rng = jax.random.key(self.config.seed) + start_step = 0 + for step in np.arange(start_step, self.config.max_train_steps): + with pipeline.mesh: + loss, state, rng = p_train_step(graphdef, state, data, rng) + max_logging.log(f"loss: {loss}") + +def train_step(graphdef, state, data, rng): + return step_optimizer(graphdef, state, data, rng) + +def step_optimizer(graphdef, state, data, rng): + _, new_rng = jax.random.split(rng) + def loss_fn(model): + latents, prompt_embeds = data + bsz = latents.shape[0] + timesteps = jnp.array([0] * bsz, dtype=jnp.int32) + + noise = jax.random.normal( + key=new_rng, + shape=latents.shape, + dtype=latents.dtype + ) + + # TODO - add noise here + + model_pred = model( + hidden_states=noise, + timestep=timesteps, + encoder_hidden_states=prompt_embeds, + is_uncond=jnp.array(False, dtype=jnp.bool_), + slg_mask=jnp.zeros(1, dtype=jnp.bool_) + ) + target = noise - latents + loss = (target - model_pred) ** 2 + loss = jnp.mean(loss) + #breakpoint() + return loss + model, optimizer = nnx.merge(graphdef, state) + loss, grads = nnx.value_and_grad(loss_fn)(model) + optimizer.update(grads) + state = nnx.state((model, optimizer)) + state = state.to_pure_dict() + return loss, state, new_rng \ No newline at end of file From a60d2358c3ff857841c37d44434d20a9d8293c6b Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 11 Jun 2025 18:26:41 +0000 Subject: [PATCH 48/54] working training pipeline on v5p at num_frames=1 --- src/maxdiffusion/trainers/wan_trainer.py | 64 ++++++++++++++++++++---- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 1e54db989..badc2330d 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -14,6 +14,8 @@ limitations under the License. """ +import os +import datetime import functools import numpy as np import jax.numpy as jnp @@ -21,8 +23,7 @@ import jax.tree_util as jtu from flax import nnx from ..schedulers import FlaxEulerDiscreteScheduler -from .. import max_utils -from .. import max_logging +from .. import max_utils, max_logging, train_utils from ..checkpointing.wan_checkpointer import ( WanCheckpointer, WAN_CHECKPOINT @@ -35,6 +36,8 @@ def __init__(self, config): if config.train_text_encoder: raise ValueError("this script currently doesn't support training text_encoders") + self.global_batch_size = self.config.per_device_batch_size * jax.device_count() + def post_training_steps(self, pipeline, params, train_states, msg=""): pass @@ -49,7 +52,8 @@ def create_scheduler(self, pipeline, params): return noise_scheduler, noise_scheduler_state def calculate_tflops(self, pipeline): - pass + max_logging.log(f"WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0...") + return 0 def load_dataset(self, pipeline): # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314 @@ -60,10 +64,9 @@ def load_dataset(self, pipeline): # prompt embeds shape: (1, 512, 4096) # For now, we will pass the same latents over and over # TODO - create a dataset - global_batch_size = self.config.per_device_batch_size * jax.device_count() - prompt_embeds = jax.random.normal(jax.random.key(self.config.seed), (global_batch_size, 512, 4096)) + prompt_embeds = jax.random.normal(jax.random.key(self.config.seed), (self.global_batch_size, 512, 4096)) latents = pipeline.prepare_latents( - global_batch_size, + self.global_batch_size, vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, height=self.config.height, @@ -92,12 +95,24 @@ def start_training(self): #graphdef, state = nnx.plit((pipeline.transformer, optimizer)) dummy_inputs = self.load_dataset(pipeline) dummy_inputs = tuple([jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs]) - self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs) def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): graphdef, state = nnx.split((pipeline.transformer, optimizer)) + writer = max_utils.initialize_summary_writer(self.config) + num_model_parameters = max_utils.calculate_num_params_from_pytree(state[0]) + max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer) + max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer) + max_utils.add_config_to_summary_writer(self.config, writer) + + if jax.process_index() == 0: + max_logging.log("***** Running training *****") + max_logging.log(f" Instantaneous batch size per device = {self.config.per_device_batch_size}") + max_logging.log(f" Total train batch size (w. parallel & distributed) = {self.global_batch_size}") + max_logging.log(f" Total optimization steps = {self.config.max_train_steps}") + + state = state.to_pure_dict() p_train_step = jax.jit( train_step, @@ -105,10 +120,36 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): ) rng = jax.random.key(self.config.seed) start_step = 0 + last_step_completion = datetime.datetime.now() + local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None + running_gcs_metrics = [] if self.config.gcs_metrics else None + first_profiling_step = self.config.skip_first_n_steps_for_profiler + if self.config.enable_profiler and first_profiling_step >= self.config.max_train_steps: + raise ValueError("Profiling requested but initial profiling step set past training final step") + last_profiling_step = np.clip( + first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1 + ) + # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint. + start_step = 0 + per_device_tflops = self.calculate_tflops(pipeline) + for step in np.arange(start_step, self.config.max_train_steps): - with pipeline.mesh: - loss, state, rng = p_train_step(graphdef, state, data, rng) - max_logging.log(f"loss: {loss}") + if self.config.enable_profiler and step == first_profiling_step: + max_utils.activate_profiler(self.config) + with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh: + state, train_metric, rng = p_train_step(graphdef, state, data, rng) + + new_time = datetime.datetime.now() + + if self.config.enable_profiler and step == last_profiling_step: + max_utils.deactivate_profiler(self.config) + + train_utils.record_scalar_metrics( + train_metric, new_time - last_step_completion, per_device_tflops, learning_rate_scheduler(step) + ) + if self.config.write_metrics: + train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) + last_step_completion = new_time def train_step(graphdef, state, data, rng): return step_optimizer(graphdef, state, data, rng) @@ -145,4 +186,5 @@ def loss_fn(model): optimizer.update(grads) state = nnx.state((model, optimizer)) state = state.to_pure_dict() - return loss, state, new_rng \ No newline at end of file + metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} + return state, metrics, new_rng \ No newline at end of file From b90584cbaa7708e8e137bc0b4af39259d7175322 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 12 Jun 2025 19:14:56 +0000 Subject: [PATCH 49/54] wan training for single frame + bug fixes. --- .../checkpointing/wan_checkpointer.py | 9 +--- src/maxdiffusion/generate_wan.py | 17 ++++---- src/maxdiffusion/maxdiffusion_utils.py | 26 ++++++++++++ .../pipelines/wan/wan_pipeline.py | 12 +++--- src/maxdiffusion/trainers/wan_trainer.py | 41 ++++--------------- 5 files changed, 52 insertions(+), 53 deletions(-) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index fefcb5371..5b08ebc10 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -15,6 +15,7 @@ """ from abc import ABC +import jax from flax import nnx from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) from ..pipelines.wan.wan_pipeline import WanPipeline @@ -35,20 +36,14 @@ def __init__(self, config, checkpoint_type): dataset_type=config.dataset_type ) - # @nnx.jit def _create_optimizer(self, model, config, learning_rate): learning_rate_scheduler = max_utils.create_learning_rate_schedule( learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps ) tx = max_utils.create_optimizer(config, learning_rate_scheduler) - # tx = nnx.Optimizer(model, tx) - - # _, state, rest_of_state = nnx.split((model, tx), ...) - # nnx.update((model, tx), state, rest_of_state) - - return nnx.Optimizer(model, tx), learning_rate_scheduler + def load_wan_configs_from_orbax(self, step): max_logging.log("Restoring stable diffusion configs") if step is None: diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 9b2838132..9ce42befe 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -29,10 +29,13 @@ def run(config): slg_layers = config.slg_layers slg_start = config.slg_start slg_end = config.slg_end - + + prompt = [config.prompt] * jax.device_count() + negative_prompt= [config.negative_prompt] * jax.device_count() + videos = pipeline( - prompt=config.prompt, - negative_prompt=config.negative_prompt, + prompt=prompt, + negative_prompt=negative_prompt, height=config.height, width=config.width, num_frames=config.num_frames, @@ -45,12 +48,12 @@ def run(config): print("compile time: ", (time.perf_counter() - s0)) for i in range(len(videos)): - export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=16) + export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps) s0 = time.perf_counter() with jax.profiler.trace("/tmp/trace/"): videos = pipeline( - prompt=config.prompt, - negative_prompt=config.negative_prompt, + prompt=prompt, + negative_prompt=negative_prompt, height=config.height, width=config.width, num_frames=config.num_frames, @@ -62,7 +65,7 @@ def run(config): ) print("generation time: ", (time.perf_counter() - s0)) for i in range(len(videos)): - export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=16) + export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps) def main(argv: Sequence[str]) -> None: diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index 43400a62e..de21d0763 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -286,6 +286,32 @@ def get_dummy_flux_inputs(config, pipeline, batch_size): return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) +def get_dummy_wan_inputs(config, pipeline, batch_size): + latents = pipeline.prepare_latents( + batch_size, + vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, + vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_channels_latents=pipeline.transformer.config.in_channels + ) + bsz = latents.shape[0] + prompt_embeds = jax.random.normal(jax.random.key(config.seed), (batch_size, 512, 4096)) + timesteps = jnp.array([0] * bsz, dtype=jnp.int32) + return (latents, prompt_embeds, timesteps) + +def calculate_wan_tflops(config, pipeline, batch_size, rngs, train): + """ + Calculates jflux tflops. + batch_size should be per_device_batch_size * jax.local_device_count() or attention's shard_map won't + cache the compilation when flash is enabled. + """ + (latents, prompt_embeds, timesteps) = get_dummy_wan_inputs(config, pipeline, batch_size) + return max_utils.calculate_model_tflops( + pipeline.transformer, + + ) def calculate_flux_tflops(config, pipeline, batch_size, rngs, train): """ diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index afe9b8bc5..3298f769b 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -58,7 +58,7 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl return vs -partial(nnx.jit, static_argnums=(3,)) +# For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): def create_model(rngs: nnx.Rngs, wan_config: dict): @@ -106,16 +106,15 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): wan_transformer = nnx.merge(graphdef, state, rest_of_state) return wan_transformer - -partial(nnx.jit, static_argnums=(1,)) +@nnx.jit(static_argnums=(1,), donate_argnums=(0,)) def create_sharded_logical_model(model, logical_axis_rules): graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...) p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=logical_axis_rules) state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) pspecs = nnx.get_partition_spec(state) sharded_state = jax.lax.with_sharding_constraint(state, pspecs) - wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - return wan_transformer + model = nnx.merge(graphdef, sharded_state, rest_of_state) + return model class WanPipeline: @@ -473,9 +472,8 @@ def transformer_forward_pass( encoder_hidden_states=prompt_embeds, is_uncond=is_uncond, slg_mask=slg_mask - )[0] + ) -#@partial(jax.jit, static_argnums=(6, 7, 8)) def run_inference( graphdef, sharded_state, diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index badc2330d..c60d59b97 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -23,7 +23,7 @@ import jax.tree_util as jtu from flax import nnx from ..schedulers import FlaxEulerDiscreteScheduler -from .. import max_utils, max_logging, train_utils +from .. import max_utils, max_logging, train_utils, maxdiffusion_utils from ..checkpointing.wan_checkpointer import ( WanCheckpointer, WAN_CHECKPOINT @@ -64,36 +64,15 @@ def load_dataset(self, pipeline): # prompt embeds shape: (1, 512, 4096) # For now, we will pass the same latents over and over # TODO - create a dataset - prompt_embeds = jax.random.normal(jax.random.key(self.config.seed), (self.global_batch_size, 512, 4096)) - latents = pipeline.prepare_latents( - self.global_batch_size, - vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, - vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, - height=self.config.height, - width=self.config.width, - num_frames=self.config.num_frames, - num_channels_latents=pipeline.transformer.config.in_channels - ) - return (latents, prompt_embeds) + return maxdiffusion_utils.get_dummy_wan_inputs(self.config, pipeline, self.global_batch_size) def start_training(self): pipeline = self.load_checkpoint() - mesh = pipeline.mesh - - optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, self.config.learning_rate) - - # @nnx.jit - # def create_transformer_state(transformer): - # optimizer = self._create_optimizer(transformer, self.config, self.config.learning_rate) - # breakpoint() - # _, state = nnx.split((transformer, optimizer)) - - # with mesh: - # create_transformer_state(pipeline.transformer) - - #graphdef, state = nnx.plit((pipeline.transformer, optimizer)) + del pipeline.vae dummy_inputs = self.load_dataset(pipeline) + mesh = pipeline.mesh + optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5) dummy_inputs = tuple([jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs]) self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs) @@ -116,7 +95,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): state = state.to_pure_dict() p_train_step = jax.jit( train_step, - donate_argnums=(1,), + donate_argnums=(0,), ) rng = jax.random.key(self.config.seed) start_step = 0 @@ -137,7 +116,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): if self.config.enable_profiler and step == first_profiling_step: max_utils.activate_profiler(self.config) with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh: - state, train_metric, rng = p_train_step(graphdef, state, data, rng) + state, train_metric, rng = p_train_step(state, graphdef, data, rng) new_time = datetime.datetime.now() @@ -151,15 +130,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) last_step_completion = new_time -def train_step(graphdef, state, data, rng): +def train_step(state, graphdef, data, rng): return step_optimizer(graphdef, state, data, rng) def step_optimizer(graphdef, state, data, rng): _, new_rng = jax.random.split(rng) def loss_fn(model): - latents, prompt_embeds = data - bsz = latents.shape[0] - timesteps = jnp.array([0] * bsz, dtype=jnp.int32) + latents, prompt_embeds, timesteps = data noise = jax.random.normal( key=new_rng, From 3bedc5debfdc5b7eda35229a6805de833a01b456 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 12 Jun 2025 23:04:14 +0000 Subject: [PATCH 50/54] lint. --- .../checkpointing/wan_checkpointer.py | 24 +- src/maxdiffusion/generate_wan.py | 47 +- src/maxdiffusion/maxdiffusion_utils.py | 20 +- src/maxdiffusion/models/attention_flax.py | 345 ++-- src/maxdiffusion/models/embeddings_flax.py | 172 +- .../models/modeling_flax_utils.py | 11 +- src/maxdiffusion/models/normalization_flax.py | 23 +- .../models/wan/autoencoder_kl_wan.py | 402 ++--- .../models/wan/transformers/__init__.py | 2 +- .../wan/transformers/transformer_wan.py | 437 +++-- src/maxdiffusion/models/wan/wan_utils.py | 21 +- src/maxdiffusion/pipelines/wan/__init__.py | 2 +- .../pipelines/wan/wan_pipeline.py | 366 ++-- .../scheduling_unipc_multistep_flax.py | 1545 ++++++++--------- .../tests/wan_transformer_test.py | 197 +-- src/maxdiffusion/tests/wan_vae_test.py | 89 +- src/maxdiffusion/train_wan.py | 6 +- src/maxdiffusion/trainers/wan_trainer.py | 50 +- 18 files changed, 1814 insertions(+), 1945 deletions(-) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 5b08ebc10..5f64d4880 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -15,7 +15,6 @@ """ from abc import ABC -import jax from flax import nnx from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) from ..pipelines.wan.wan_pipeline import WanPipeline @@ -23,27 +22,28 @@ WAN_CHECKPOINT = "WAN_CHECKPOINT" + class WanCheckpointer(ABC): + def __init__(self, config, checkpoint_type): self.config = config self.checkpoint_type = checkpoint_type self.checkpoint_manager = create_orbax_checkpoint_manager( - self.config.checkpoint_dir, - enable_checkpointing=True, - save_interval_steps=1, - checkpoint_type=checkpoint_type, - dataset_type=config.dataset_type + self.config.checkpoint_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=checkpoint_type, + dataset_type=config.dataset_type, ) - + def _create_optimizer(self, model, config, learning_rate): learning_rate_scheduler = max_utils.create_learning_rate_schedule( - learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps + learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps ) tx = max_utils.create_optimizer(config, learning_rate_scheduler) return nnx.Optimizer(model, tx), learning_rate_scheduler - def load_wan_configs_from_orbax(self, step): max_logging.log("Restoring stable diffusion configs") if step is None: @@ -59,8 +59,8 @@ def load_checkpoint(self, step=None): model_configs = self.load_wan_configs_from_orbax(step) if model_configs: - raise NotImplemented("model configs should not exist in orbax") + raise NotImplementedError("model configs should not exist in orbax") else: pipeline = self.load_diffusers_checkpoint() - - return pipeline \ No newline at end of file + + return pipeline diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 9ce42befe..5791d8a8f 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -20,38 +20,21 @@ from absl import app from maxdiffusion.utils import export_to_video + def run(config): print("seed: ", config.seed) pipeline = WanPipeline.from_pretrained(config) s0 = time.perf_counter() - + # Skip layer guidance slg_layers = config.slg_layers slg_start = config.slg_start slg_end = config.slg_end prompt = [config.prompt] * jax.device_count() - negative_prompt= [config.negative_prompt] * jax.device_count() - - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, - slg_layers=slg_layers, - slg_start=slg_start, - slg_end=slg_end - ) + negative_prompt = [config.negative_prompt] * jax.device_count() - print("compile time: ", (time.perf_counter() - s0)) - for i in range(len(videos)): - export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps) - s0 = time.perf_counter() - with jax.profiler.trace("/tmp/trace/"): - videos = pipeline( + videos = pipeline( prompt=prompt, negative_prompt=negative_prompt, height=config.height, @@ -61,7 +44,25 @@ def run(config): guidance_scale=config.guidance_scale, slg_layers=slg_layers, slg_start=slg_start, - slg_end=slg_end + slg_end=slg_end, + ) + + print("compile time: ", (time.perf_counter() - s0)) + for i in range(len(videos)): + export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps) + s0 = time.perf_counter() + with jax.profiler.trace("/tmp/trace/"): + videos = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + slg_layers=slg_layers, + slg_start=slg_start, + slg_end=slg_end, ) print("generation time: ", (time.perf_counter() - s0)) for i in range(len(videos)): @@ -74,4 +75,4 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) \ No newline at end of file + app.run(main) diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index de21d0763..b9b1abdcb 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -286,21 +286,23 @@ def get_dummy_flux_inputs(config, pipeline, batch_size): return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) + def get_dummy_wan_inputs(config, pipeline, batch_size): latents = pipeline.prepare_latents( - batch_size, - vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, - vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_channels_latents=pipeline.transformer.config.in_channels + batch_size, + vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, + vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_channels_latents=pipeline.transformer.config.in_channels, ) bsz = latents.shape[0] prompt_embeds = jax.random.normal(jax.random.key(config.seed), (batch_size, 512, 4096)) timesteps = jnp.array([0] * bsz, dtype=jnp.int32) return (latents, prompt_embeds, timesteps) + def calculate_wan_tflops(config, pipeline, batch_size, rngs, train): """ Calculates jflux tflops. @@ -309,10 +311,10 @@ def calculate_wan_tflops(config, pipeline, batch_size, rngs, train): """ (latents, prompt_embeds, timesteps) = get_dummy_wan_inputs(config, pipeline, batch_size) return max_utils.calculate_model_tflops( - pipeline.transformer, - + pipeline.transformer, ) + def calculate_flux_tflops(config, pipeline, batch_size, rngs, train): """ Calculates jflux tflops. diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index bce251dd0..006614f87 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -15,7 +15,6 @@ import functools import math from typing import Optional, Callable, Tuple -import numpy as np import flax.linen as nn from flax import nnx import jax @@ -48,6 +47,7 @@ def _maybe_aqt_einsum(quant: Quant): return jnp.einsum if quant is None else quant.einsum() + def _check_attention_inputs(query: Array, key: Array, value: Array) -> None: """Check attention inputs.""" @@ -57,16 +57,19 @@ def _check_attention_inputs(query: Array, key: Array, value: Array) -> None: assert key.shape[-3] == value.shape[-3], "k, v lengths must match." assert query.shape[-1] == key.shape[-1], "q, k depths must match." + def _reshape_data_from_cudnn_flash(tensor): # reshapes from [b, s, h, d] back to [b, s, h * d] return tensor.reshape(tensor.shape[0], tensor.shape[1], -1) + def _reshape_data_for_cudnn_flash(tensor, heads): # reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format) batch, seq, heads_and_dim_head = tensor.shape tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads) return tensor + def _reshape_batch_dim_to_heads(tensor, heads): batch_size, seq_len, dim = tensor.shape head_size = heads @@ -75,6 +78,7 @@ def _reshape_batch_dim_to_heads(tensor, heads): tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) return tensor + def _reshape_heads_to_batch_dim(tensor, heads): if tensor.ndim == 3: batch_size, seq_len, dim = tensor.shape @@ -85,9 +89,10 @@ def _reshape_heads_to_batch_dim(tensor, heads): else: batch_size, head_size, seq_len, head_dim = tensor.shape tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim) - + return tensor + def _reshape_heads_to_head_dim(tensor): # takes a tensor of shape [b, h, s, d] and reshapes to [b, s, h * d] # This is used to transform the output of flash attention back into the format of other attention outputs @@ -95,6 +100,7 @@ def _reshape_heads_to_head_dim(tensor): tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3]) return jnp.reshape(tensor, (b, -1, h * d)) + def _unflatten_heads(tensor, heads): # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format) batch, seq, heads_and_dim_head = tensor.shape @@ -103,13 +109,14 @@ def _unflatten_heads(tensor, heads): tensor = jnp.transpose(tensor, (0, 2, 1, 3)) return tensor + def _reshape_data_for_flash(tensor, heads, flash_block_size): """ Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. """ - if tensor.ndim != 4: + if tensor.ndim != 4: tensor = _unflatten_heads(tensor, heads) - + # pad head_dim to 128 if less than that. kv_size = tensor.shape[-1] head_dim_pad = 0 @@ -126,22 +133,24 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size): mul = seq_len // flash_block_size # pad to the closest multiplier of flash_block_size seq_len_pad = (mul + 1) * flash_block_size - seq_len - + if kv_size < 128 or rem != 0: npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad)) tensor = jnp.pad(tensor, npad) return tensor, kv_size, seq_len + def _tpu_flash_attention( - query: jax.Array, - key: jax.Array, - value: jax.Array, - heads: int, - mesh: Mesh, - flash_axis_names: AxisNames, - flash_block_sizes: BlockSizes, - dtype: jnp.dtype = jnp.float32) -> jax.Array: + query: jax.Array, + key: jax.Array, + value: jax.Array, + heads: int, + mesh: Mesh, + flash_axis_names: AxisNames, + flash_block_sizes: BlockSizes, + dtype: jnp.dtype = jnp.float32, +) -> jax.Array: """TPU Flash Attention""" max_block_size = 1024 if dtype == jnp.bfloat16 else 512 @@ -198,17 +207,18 @@ def wrap_flash_attention(query, key, value): return x + def _apply_attention_dot( - query: Array, - key: Array, - value: Array, - dtype: jnp.dtype, - heads: int, - dim_head: int, - scale: float, - split_head_dim: bool, - float32_qk_product: bool, - use_memory_efficient_attention: bool + query: Array, + key: Array, + value: Array, + dtype: jnp.dtype, + heads: int, + dim_head: int, + scale: float, + split_head_dim: bool, + float32_qk_product: bool, + use_memory_efficient_attention: bool, ): """Apply Attention.""" if split_head_dim: @@ -270,14 +280,8 @@ def _apply_attention_dot( return hidden_states -def _cudnn_flash_attention( - query: Array, - key: Array, - value: Array, - heads: int, - mesh: Mesh, - dpa_layer: Callable -) -> Array: + +def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, mesh: Mesh, dpa_layer: Callable) -> Array: """CUDNN Flash Attention with Transformer Engine. 1. Stable API, supports GQA 2. Supports head_dim till 128; head_dim=256 support will be added soon @@ -308,24 +312,25 @@ def wrap_flash_attention(query, key, value): out = wrap_flash_attention(query, key, value) return _reshape_data_from_cudnn_flash(out) + def _apply_attention( - query: Array, - key: Array, - value: Array, - heads: int, - dim_head: int, - split_head_dim: bool, - float32_qk_product: bool, - attention_kernel: str, - flash_min_seq_length: int, - use_memory_efficient_attention: bool, - scale: float, - dtype: jnp.dtype, - mesh: Mesh, - flash_axis_names: AxisNames, - flash_block_sizes: BlockSizes, - dpa_layer: Callable - ): + query: Array, + key: Array, + value: Array, + heads: int, + dim_head: int, + split_head_dim: bool, + float32_qk_product: bool, + attention_kernel: str, + flash_min_seq_length: int, + use_memory_efficient_attention: bool, + scale: float, + dtype: jnp.dtype, + mesh: Mesh, + flash_axis_names: AxisNames, + flash_block_sizes: BlockSizes, + dpa_layer: Callable, +): """Routes to different attention kernels.""" _check_attention_inputs(query, key, value) seq_len_idx = 1 @@ -341,7 +346,9 @@ def _apply_attention( can_use_flash_attention = True if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention: - return _apply_attention_dot(query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention) + return _apply_attention_dot( + query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention + ) elif attention_kernel == "flash": return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype) elif attention_kernel == "cudnn_flash_te": @@ -349,6 +356,7 @@ def _apply_attention( else: raise ValueError(f"Unexpected attention kernel {attention_kernel=}.") + def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): """Multi-head dot product attention with a limited number of queries.""" num_kv, num_heads, k_features = key.shape[-3:] @@ -454,26 +462,27 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: class NNXAttentionOp(nnx.Module): + def __init__( - self, - mesh: Mesh, - attention_kernel: str, - scale: int, - heads: int, - dim_head: int, - use_memory_efficient_attention: bool = False, - split_head_dim: bool = False, - float32_qk_product: bool = True, - flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV), - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - dtype: DType = jnp.float32, - quant: Quant = None, + self, + mesh: Mesh, + attention_kernel: str, + scale: int, + heads: int, + dim_head: int, + use_memory_efficient_attention: bool = False, + split_head_dim: bool = False, + float32_qk_product: bool = True, + flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV), + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + dtype: DType = jnp.float32, + quant: Quant = None, ): self.dpa_layer = None if attention_kernel == "cudnn_flash_te": raise NotImplementedError(f"{self} has not been tested with {attention_kernel}") - + self.mesh = mesh self.scale = scale self.heads = heads @@ -481,33 +490,34 @@ def __init__( self.attention_kernel = attention_kernel self.use_memory_efficient_attention = use_memory_efficient_attention self.split_head_dim = split_head_dim - self.float32_qk_product=float32_qk_product - self.flash_axis_names=flash_axis_names - self.flash_min_seq_length=flash_min_seq_length - self.flash_block_sizes=flash_block_sizes - self.dtype=dtype - self.quant=quant - + self.float32_qk_product = float32_qk_product + self.flash_axis_names = flash_axis_names + self.flash_min_seq_length = flash_min_seq_length + self.flash_block_sizes = flash_block_sizes + self.dtype = dtype + self.quant = quant + def apply_attention(self, query: Array, key: Array, value: Array): return _apply_attention( - query=query, - key=key, - value=value, - heads=self.heads, - dim_head=self.dim_head, - split_head_dim=self.split_head_dim, - float32_qk_product=self.float32_qk_product, - attention_kernel=self.attention_kernel, - flash_min_seq_length=self.flash_min_seq_length, - use_memory_efficient_attention=self.use_memory_efficient_attention, - scale=self.scale, - dtype=self.dtype, - mesh=self.mesh, - flash_axis_names=self.flash_axis_names, - flash_block_sizes=self.flash_block_sizes, - dpa_layer=self.dpa_layer + query=query, + key=key, + value=value, + heads=self.heads, + dim_head=self.dim_head, + split_head_dim=self.split_head_dim, + float32_qk_product=self.float32_qk_product, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + use_memory_efficient_attention=self.use_memory_efficient_attention, + scale=self.scale, + dtype=self.dtype, + mesh=self.mesh, + flash_axis_names=self.flash_axis_names, + flash_block_sizes=self.flash_block_sizes, + dpa_layer=self.dpa_layer, ) + class AttentionOp(nn.Module): mesh: Mesh attention_kernel: str @@ -545,51 +555,52 @@ def setup(self): def apply_attention(self, query: Array, key: Array, value: Array): return _apply_attention( - query=query, - key=key, - value=value, - heads=self.heads, - dim_head=self.dim_head, - split_head_dim=self.split_head_dim, - float32_qk_product=self.float32_qk_product, - attention_kernel=self.attention_kernel, - flash_min_seq_length=self.flash_min_seq_length, - use_memory_efficient_attention=self.use_memory_efficient_attention, - scale=self.scale, - dtype=self.dtype, - mesh=self.mesh, - flash_axis_names=self.flash_axis_names, - flash_block_sizes=self.flash_block_sizes, - dpa_layer=self.dpa_layer + query=query, + key=key, + value=value, + heads=self.heads, + dim_head=self.dim_head, + split_head_dim=self.split_head_dim, + float32_qk_product=self.float32_qk_product, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + use_memory_efficient_attention=self.use_memory_efficient_attention, + scale=self.scale, + dtype=self.dtype, + mesh=self.mesh, + flash_axis_names=self.flash_axis_names, + flash_block_sizes=self.flash_block_sizes, + dpa_layer=self.dpa_layer, ) class FlaxWanAttention(nnx.Module): + def __init__( - self, - rngs: nnx.Rngs, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - eps: float = 1e-6, - qk_norm: str = "rms_norm_across_heads", - use_memory_efficient_attention: bool = False, - split_head_dim: bool = False, - attention_kernel: str = "flash", - flash_min_seq_length: int = 4096, - flash_block_sizes: BlockSizes = None, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - query_axis_names: AxisNames = (BATCH, LENGTH, HEAD), - key_axis_names: AxisNames = (BATCH, LENGTH, HEAD), - value_axis_names: AxisNames = (BATCH, LENGTH, HEAD), - out_axis_names: AxisNames = (BATCH, LENGTH, EMBED), - precision: jax.lax.Precision = None, - qkv_bias: bool = False, - quant: Quant = None, + self, + rngs: nnx.Rngs, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + eps: float = 1e-6, + qk_norm: str = "rms_norm_across_heads", + use_memory_efficient_attention: bool = False, + split_head_dim: bool = False, + attention_kernel: str = "flash", + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + query_axis_names: AxisNames = (BATCH, LENGTH, HEAD), + key_axis_names: AxisNames = (BATCH, LENGTH, HEAD), + value_axis_names: AxisNames = (BATCH, LENGTH, HEAD), + out_axis_names: AxisNames = (BATCH, LENGTH, EMBED), + precision: jax.lax.Precision = None, + qkv_bias: bool = False, + quant: Quant = None, ): if attention_kernel == "cudnn_flash_te": raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}") @@ -605,20 +616,20 @@ def __init__( self.key_axis_names = key_axis_names self.value_axis_names = value_axis_names self.out_axis_names = out_axis_names - + self.attention_op = NNXAttentionOp( - mesh=mesh, - attention_kernel=attention_kernel, - scale=scale, - heads=heads, - dim_head=dim_head, - use_memory_efficient_attention=use_memory_efficient_attention, - split_head_dim=split_head_dim, - float32_qk_product=False, - flash_min_seq_length=flash_min_seq_length, - flash_block_sizes=flash_block_sizes, - dtype=dtype, - quant=quant + mesh=mesh, + attention_kernel=attention_kernel, + scale=scale, + heads=heads, + dim_head=dim_head, + use_memory_efficient_attention=use_memory_efficient_attention, + split_head_dim=split_head_dim, + float32_qk_product=False, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + dtype=dtype, + quant=quant, ) kernel_axes = ("embed", "heads") @@ -655,33 +666,33 @@ def __init__( ) self.proj_attn = nnx.Linear( - rngs=rngs, - in_features=self.inner_dim, - out_features=self.inner_dim, - kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")), - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, + rngs=rngs, + in_features=self.inner_dim, + out_features=self.inner_dim, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) self.norm_q = None self.norm_k = None if qk_norm is not None: self.norm_q = nnx.RMSNorm( - num_features=self.inner_dim, - rngs=rngs, - epsilon=eps, - dtype=dtype, - scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm", )), - param_dtype=weights_dtype + num_features=self.inner_dim, + rngs=rngs, + epsilon=eps, + dtype=dtype, + scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)), + param_dtype=weights_dtype, ) self.norm_k = nnx.RMSNorm( - num_features=self.inner_dim, - rngs=rngs, - dtype=dtype, - scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm", )), - param_dtype=weights_dtype + num_features=self.inner_dim, + rngs=rngs, + dtype=dtype, + scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)), + param_dtype=weights_dtype, ) def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]: @@ -701,13 +712,10 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup return xq_out, xk_out def __call__( - self, - hidden_states: jax.Array, - encoder_hidden_states: jax.Array = None, - rotary_emb: Optional[jax.Array] = None + self, hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None ) -> jax.Array: - hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec('data', 'fsdp','tensor')) - encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec('data', 'fsdp','tensor')) + hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) + encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor")) dtype = hidden_states.dtype if encoder_hidden_states is None: encoder_hidden_states = hidden_states @@ -724,12 +732,12 @@ def __call__( key_proj = _unflatten_heads(key_proj, self.heads) value_proj = _unflatten_heads(value_proj, self.heads) query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) - query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec('data', 'tensor', None, None)) - key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec('data', 'tensor', None, None)) - value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec('data', 'tensor', None, None)) + query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec("data", "tensor", None, None)) + key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec("data", "tensor", None, None)) + value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec("data", "tensor", None, None)) attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) - attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec('data', None, None)) + attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", None, None)) attn_output = attn_output.astype(dtype=dtype) @@ -1318,6 +1326,7 @@ def __call__(self, hidden_states, context, deterministic=True, cross_attention_k hidden_states = hidden_states + residual return self.dropout_layer(hidden_states, deterministic=deterministic) + class FlaxFeedForward(nn.Module): r""" Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index ef57aaf63..d994b46e7 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -58,6 +58,7 @@ def get_sinusoidal_embeddings( signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) return signal + class NNXTimestepEmbedding(nnx.Module): r""" Time step Embedding Module. Learns embeddings for input time steps. @@ -68,56 +69,69 @@ class NNXTimestepEmbedding(nnx.Module): dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ + def __init__( - self, - rngs: nnx.Rngs, - in_channels: int, - time_embed_dim: int = 32, - act_fn: str = "silu", - out_dim: int = None, - post_act_fn: Optional[str] = None, - cond_proj_dim: int = None, - sample_proj_bias=True, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, + self, + rngs: nnx.Rngs, + in_channels: int, + time_embed_dim: int = 32, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim: int = None, + sample_proj_bias=True, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.linear_1 = nnx.Linear( - rngs=rngs, - in_features=in_channels, - out_features=time_embed_dim, - use_bias=sample_proj_bias, - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + rngs=rngs, + in_features=in_channels, + out_features=time_embed_dim, + use_bias=sample_proj_bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "embed", + "mlp", + ), + ), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), ) if cond_proj_dim is not None: self.cond_proj = nnx.Linear( - rngs=rngs, + rngs=rngs, ) else: self.cond_proj = None - + self.act = get_activation(act_fn) if out_dim is not None: time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - + self.linear_2 = nnx.Linear( - rngs=rngs, - in_features=time_embed_dim, - out_features=time_embed_dim_out, - use_bias=sample_proj_bias, - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed",)), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + rngs=rngs, + in_features=time_embed_dim, + out_features=time_embed_dim_out, + use_bias=sample_proj_bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "mlp", + "embed", + ), + ), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), ) if post_act_fn is None: @@ -161,13 +175,15 @@ def __call__(self, temb): temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, param_dtype=self.weights_dtype, name="linear_2")(temb) return temb + class NNXFlaxTimesteps(nnx.Module): + def __init__( - self, - dim: int = 32, - flip_sin_to_cos: bool = False, - freq_shift: float = 1.0, - scale: int = 1, + self, + dim: int = 32, + flip_sin_to_cos: bool = False, + freq_shift: float = 1.0, + scale: int = 1, ): self.dim = dim self.flip_sin_to_cos = flip_sin_to_cos @@ -176,8 +192,9 @@ def __init__( def __call__(self, timesteps): return get_sinusoidal_embeddings( - timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift - ) + timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift + ) + class FlaxTimesteps(nn.Module): r""" @@ -207,7 +224,7 @@ def get_1d_rotary_pos_embed( linear_factor=1.0, ntk_factor=1.0, freqs_dtype=jnp.float32, - use_real: bool = True + use_real: bool = True, ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -230,52 +247,67 @@ def get_1d_rotary_pos_embed( out = jnp.exp(1j * freqs) return out + class NNXPixArtAlphaTextProjection(nnx.Module): + def __init__( - self, - rngs: nnx.Rngs, - in_features: int, - hidden_size: int, - out_features: int = None, - act_fn: str = "gelu_tanh", - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None + self, + rngs: nnx.Rngs, + in_features: int, + hidden_size: int, + out_features: int = None, + act_fn: str = "gelu_tanh", + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): if out_features is None: out_features = hidden_size - + self.linear_1 = nnx.Linear( - rngs=rngs, - in_features=in_features, - out_features=hidden_size, - use_bias=True, - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + rngs=rngs, + in_features=in_features, + out_features=hidden_size, + use_bias=True, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "embed", + "mlp", + ), + ), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), ) self.act_1 = get_activation(act_fn) self.linear_2 = nnx.Linear( - rngs=rngs, - in_features=hidden_size, - out_features=out_features, - use_bias=True, - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed",)), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + rngs=rngs, + in_features=hidden_size, + out_features=out_features, + use_bias=True, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "mlp", + "embed", + ), + ), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), ) - + def __call__(self, caption): hidden_states = self.linear_1(caption) hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states + class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. diff --git a/src/maxdiffusion/models/modeling_flax_utils.py b/src/maxdiffusion/models/modeling_flax_utils.py index 2e300f70e..5e08a2eb8 100644 --- a/src/maxdiffusion/models/modeling_flax_utils.py +++ b/src/maxdiffusion/models/modeling_flax_utils.py @@ -43,7 +43,15 @@ logger = logging.get_logger(__name__) # gelu and gelu_tanh both use approximate=True by default -_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "gelu_tanh" : jax.nn.gelu, "mish": jax.nn.mish} +_ACTIVATIONS = { + "swish": jax.nn.silu, + "silu": jax.nn.silu, + "relu": jax.nn.relu, + "gelu": jax.nn.gelu, + "gelu_tanh": jax.nn.gelu, + "mish": jax.nn.mish, +} + def get_activation(name: str): func = _ACTIVATIONS.get(name) @@ -51,6 +59,7 @@ def get_activation(name: str): raise ValueError(f"Unknown activation function: {name}") return func + class FlaxModelMixin(PushToHubMixin): r""" Base class for all Flax models. diff --git a/src/maxdiffusion/models/normalization_flax.py b/src/maxdiffusion/models/normalization_flax.py index 8c8463e62..2ba658d4b 100644 --- a/src/maxdiffusion/models/normalization_flax.py +++ b/src/maxdiffusion/models/normalization_flax.py @@ -149,19 +149,20 @@ def __call__(self, x, emb): raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided. Supported ones are: 'layer_norm'.") return x, gate_msa + class FP32LayerNorm(nnx.Module): - def __init__(self, rngs: nnx.Rngs, dim: int, eps : float, elementwise_affine: bool): + + def __init__(self, rngs: nnx.Rngs, dim: int, eps: float, elementwise_affine: bool): self.layer_norm = nnx.LayerNorm( - rngs=rngs, - num_features=dim, - epsilon=eps, - use_bias=elementwise_affine, - use_scale=elementwise_affine, - param_dtype=jnp.float32, - dtype=jnp.float32 + rngs=rngs, + num_features=dim, + epsilon=eps, + use_bias=elementwise_affine, + use_scale=elementwise_affine, + param_dtype=jnp.float32, + dtype=jnp.float32, ) + def __call__(self, inputs: jax.Array) -> jax.Array: origin_dtype = inputs.dtype - return self.layer_norm( - inputs.astype(dtype=jnp.float32) - ).astype(dtype=origin_dtype) + return self.layer_norm(inputs.astype(dtype=jnp.float32)).astype(dtype=origin_dtype) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index cf280fdff..19244f723 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -85,12 +85,10 @@ def __init__( use_bias=use_bias, padding="VALID", # Handle padding manually rngs=rngs, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), - kernel_sharding - ), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), kernel_sharding), dtype=dtype, param_dtype=weights_dtype, - precision=precision + precision=precision, ) def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array: @@ -197,17 +195,16 @@ def __init__( precision: jax.lax.Precision = None, ): self.conv = nnx.Conv( - dim, dim, - kernel_size=kernel_size, - strides=stride, - use_bias=True, - rngs=rngs, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), - (None, None, None, None) - ), - dtype=dtype, - param_dtype=weights_dtype, - precision=precision + dim, + dim, + kernel_size=kernel_size, + strides=stride, + use_bias=True, + rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, None)), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) def __call__(self, x): @@ -240,12 +237,10 @@ def __init__( padding="SAME", use_bias=True, rngs=rngs, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), - (None, None, None, "conv_out") - ), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, - precision=precision + precision=precision, ), ) elif mode == "upsample3d": @@ -258,12 +253,10 @@ def __init__( padding="SAME", use_bias=True, rngs=rngs, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), - (None, None, None, "conv_out") - ), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, - precision=precision + precision=precision, ), ) self.time_conv = WanCausalConv3d( @@ -275,29 +268,29 @@ def __init__( mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, - precision=precision + precision=precision, ) elif mode == "downsample2d": self.resample = ZeroPaddedConv2D( - dim=dim, - rngs=rngs, - kernel_size=(3, 3), - stride=(2, 2), - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + dim=dim, + rngs=rngs, + kernel_size=(3, 3), + stride=(2, 2), + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) elif mode == "downsample3d": self.resample = ZeroPaddedConv2D( - dim=dim, - rngs=rngs, - kernel_size=(3, 3), - stride=(2, 2), - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + dim=dim, + rngs=rngs, + kernel_size=(3, 3), + stride=(2, 2), + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) self.time_conv = WanCausalConv3d( rngs=rngs, @@ -309,7 +302,7 @@ def __init__( mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, - precision=precision + precision=precision, ) else: self.resample = Identity() @@ -381,37 +374,38 @@ def __init__( # layers self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False) self.conv1 = WanCausalConv3d( - rngs=rngs, - in_channels=in_dim, - out_channels=out_dim, - kernel_size=3, - padding=1, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + rngs=rngs, + in_channels=in_dim, + out_channels=out_dim, + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False) - self.conv2 = WanCausalConv3d(rngs=rngs, - in_channels=out_dim, - out_channels=out_dim, - kernel_size=3, - padding=1, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + self.conv2 = WanCausalConv3d( + rngs=rngs, + in_channels=out_dim, + out_channels=out_dim, + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) self.conv_shortcut = ( WanCausalConv3d( - rngs=rngs, - in_channels=in_dim, - out_channels=out_dim, - kernel_size=1, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + rngs=rngs, + in_channels=in_dim, + out_channels=out_dim, + kernel_size=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) if in_dim != out_dim else Identity() @@ -463,26 +457,28 @@ def __init__( dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None, - ): + ): self.dim = dim self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False) self.to_qkv = nnx.Conv( - in_features=dim, out_features=dim * 3, kernel_size=(1, 1), rngs=rngs, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), - (None, None, None, "conv_out") - ), - dtype=dtype, - param_dtype=weights_dtype, - precision=precision + in_features=dim, + out_features=dim * 3, + kernel_size=(1, 1), + rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) self.proj = nnx.Conv( - in_features=dim, out_features=dim, kernel_size=(1, 1), rngs=rngs, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), - (None, None, "conv_in", None) - ), - dtype=dtype, - param_dtype=weights_dtype, - precision=precision + in_features=dim, + out_features=dim, + kernel_size=(1, 1), + rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "conv_in", None)), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) def __call__(self, x: jax.Array): @@ -515,7 +511,8 @@ def __call__(self, x: jax.Array): class WanMidBlock(nnx.Module): def __init__( - self, dim: int, + self, + dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity: str = "silu", @@ -524,45 +521,38 @@ def __init__( dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None, - ): + ): self.dim = dim resnets = [ - WanResidualBlock( - in_dim=dim, - out_dim=dim, - rngs=rngs, - dropout=dropout, - non_linearity=non_linearity, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision - ) + WanResidualBlock( + in_dim=dim, + out_dim=dim, + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) ] attentions = [] for _ in range(num_layers): attentions.append( - WanAttentionBlock( - dim=dim, - rngs=rngs, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision - ) + WanAttentionBlock(dim=dim, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision) ) resnets.append( - WanResidualBlock( - in_dim=dim, - out_dim=dim, - rngs=rngs, - dropout=dropout, - non_linearity=non_linearity, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision - ) + WanResidualBlock( + in_dim=dim, + out_dim=dim, + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) ) self.attentions = attentions self.resnets = resnets @@ -599,15 +589,15 @@ def __init__( for _ in range(num_res_blocks + 1): resnets.append( WanResidualBlock( - in_dim=current_dim, - out_dim=out_dim, - dropout=dropout, - non_linearity=non_linearity, - rngs=rngs, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + in_dim=current_dim, + out_dim=out_dim, + dropout=dropout, + non_linearity=non_linearity, + rngs=rngs, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) ) current_dim = out_dim @@ -617,15 +607,15 @@ def __init__( self.upsamplers = None if upsample_mode is not None: self.upsamplers = [ - WanResample( - dim=out_dim, - mode=upsample_mode, - rngs=rngs, - mesh=mesh, - weights_dtype=weights_dtype, - dtype=dtype, - precision=precision - ) + WanResample( + dim=out_dim, + mode=upsample_mode, + rngs=rngs, + mesh=mesh, + weights_dtype=weights_dtype, + dtype=dtype, + precision=precision, + ) ] def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): @@ -683,7 +673,7 @@ def __init__( mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, - precision=precision + precision=precision, ) # downsample blocks @@ -692,25 +682,21 @@ def __init__( # residual (+attention) blocks for _ in range(num_res_blocks): self.down_blocks.append( - WanResidualBlock( - in_dim=in_dim, - out_dim=out_dim, - dropout=dropout,rngs=rngs, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision - ) + WanResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + rngs=rngs, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) ) if scale in attn_scales: self.down_blocks.append( - WanAttentionBlock( - dim=out_dim, - rngs=rngs, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + WanAttentionBlock( + dim=out_dim, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision ) ) in_dim = out_dim @@ -719,15 +705,9 @@ def __init__( if i != len(dim_mult) - 1: mode = "downsample3d" if temperal_downsample[i] else "downsample2d" self.down_blocks.append( - WanResample( - out_dim, - mode=mode, - rngs=rngs, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision - ) + WanResample( + out_dim, mode=mode, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision + ) ) scale /= 2.0 @@ -741,21 +721,21 @@ def __init__( mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, - precision=precision + precision=precision, ) # output blocks self.norm_out = WanRMS_norm(out_dim, channel_first=False, images=False, rngs=rngs) self.conv_out = WanCausalConv3d( - rngs=rngs, - in_channels=out_dim, - out_channels=z_dim, - kernel_size=3, - padding=1, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + rngs=rngs, + in_channels=out_dim, + out_channels=z_dim, + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): @@ -838,28 +818,29 @@ def __init__( scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block - self.conv_in = WanCausalConv3d(rngs=rngs, - in_channels=z_dim, - out_channels=dims[0], - kernel_size=3, - padding=1, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + self.conv_in = WanCausalConv3d( + rngs=rngs, + in_channels=z_dim, + out_channels=dims[0], + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) # middle_blocks self.mid_block = WanMidBlock( - dim=dims[0], - rngs=rngs, - dropout=dropout, - non_linearity=non_linearity, - num_layers=1, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + dim=dims[0], + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + num_layers=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) # upsample blocks @@ -885,7 +866,7 @@ def __init__( mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, - precision=precision + precision=precision, ) self.up_blocks.append(up_block) @@ -895,15 +876,16 @@ def __init__( # output blocks self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False) - self.conv_out = WanCausalConv3d(rngs=rngs, - in_channels=out_dim, - out_channels=3, - kernel_size=3, - padding=1, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + self.conv_out = WanCausalConv3d( + rngs=rngs, + in_channels=out_dim, + out_channels=3, + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): @@ -1039,19 +1021,19 @@ def __init__( mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, - precision=precision + precision=precision, ) self.quant_conv = WanCausalConv3d( - rngs=rngs, - in_channels=z_dim * 2, - out_channels=z_dim * 2, - kernel_size=1, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + rngs=rngs, + in_channels=z_dim * 2, + out_channels=z_dim * 2, + kernel_size=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) - + self.post_quant_conv = WanCausalConv3d( rngs=rngs, in_channels=z_dim, @@ -1060,7 +1042,7 @@ def __init__( mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, - precision=precision + precision=precision, ) self.decoder = WanDecoder3d( @@ -1075,7 +1057,7 @@ def __init__( mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, - precision=precision + precision=precision, ) def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): @@ -1133,9 +1115,9 @@ def _decode( fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :] # When batch_size is 0, expand batch dim for contatenation # else, expand frame dim for concatenation so that batch dim stays intact. - axis=0 + axis = 0 if fm1.shape[0] > 1: - axis=1 + axis = 1 if len(fm1.shape) == 4: fm1 = jnp.expand_dims(fm1, axis=axis) diff --git a/src/maxdiffusion/models/wan/transformers/__init__.py b/src/maxdiffusion/models/wan/transformers/__init__.py index 522c1e64b..9ff757fc3 100644 --- a/src/maxdiffusion/models/wan/transformers/__init__.py +++ b/src/maxdiffusion/models/wan/transformers/__init__.py @@ -12,4 +12,4 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -""" \ No newline at end of file +""" diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index a0312b5bf..c79a21bf7 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -14,7 +14,7 @@ limitations under the License. """ -from typing import Tuple, Optional, Dict, Union, Any, List +from typing import Tuple, Optional, Dict, Union, Any import math import jax import jax.numpy as jnp @@ -24,28 +24,23 @@ from ...modeling_flax_utils import FlaxModelMixin, get_activation from ....configuration_utils import ConfigMixin, register_to_config from ...embeddings_flax import ( - get_1d_rotary_pos_embed, - NNXFlaxTimesteps, - NNXTimestepEmbedding, - NNXPixArtAlphaTextProjection + get_1d_rotary_pos_embed, + NNXFlaxTimesteps, + NNXTimestepEmbedding, + NNXPixArtAlphaTextProjection, ) from ...normalization_flax import FP32LayerNorm from ...attention_flax import FlaxWanAttention BlockSizes = common_types.BlockSizes + def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int): h_dim = w_dim = 2 * (attention_head_dim // 6) t_dim = attention_head_dim - h_dim - w_dim freqs = [] for dim in [t_dim, h_dim, w_dim]: - freq = get_1d_rotary_pos_embed( - dim, - max_seq_len, - theta, - freqs_dtype=jnp.float64, - use_real=False - ) + freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float64, use_real=False) freqs.append(freq) freqs = jnp.concatenate(freqs, axis=1) # sizes = jnp.array([ @@ -57,28 +52,24 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int): # split_indices = cumulative_sizes[:-1] t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6) hw_size = attention_head_dim // 6 - + dims = [t_size, hw_size, hw_size] - + # Calculate split indices as a static list of integers cumulative_sizes = np.cumsum(dims) split_indices = cumulative_sizes[:-1].tolist() freqs_split = jnp.split(freqs, split_indices, axis=1) return freqs_split + class WanRotaryPosEmbed(nnx.Module): - def __init__( - self, - attention_head_dim: int, - patch_size: Tuple[int, int, int], - max_seq_len: int, - theta: float = 10000.0 - ): + + def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0): self.attention_head_dim = attention_head_dim self.patch_size = patch_size self.max_seq_len = max_seq_len self.theta = theta - + def __call__(self, hidden_states: jax.Array) -> jax.Array: _, num_frames, height, width, _ = hidden_states.shape p_t, p_h, p_w = self.patch_size @@ -101,53 +92,59 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array: class WanTimeTextImageEmbedding(nnx.Module): + def __init__( - self, - rngs: nnx.Rngs, - dim: int, - time_freq_dim: int, - time_proj_dim: int, - text_embed_dim: int, - image_embed_dim: Optional[int] = None, - pos_embed_seq_len: Optional[int] = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, + self, + rngs: nnx.Rngs, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: Optional[int] = None, + pos_embed_seq_len: Optional[int] = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): - self.timesteps_proj = NNXFlaxTimesteps( - dim=time_freq_dim, flip_sin_to_cos=True, freq_shift=0 - ) + self.timesteps_proj = NNXFlaxTimesteps(dim=time_freq_dim, flip_sin_to_cos=True, freq_shift=0) self.time_embedder = NNXTimestepEmbedding( - rngs=rngs, in_channels=time_freq_dim, time_embed_dim=dim, - dtype=dtype, weights_dtype=weights_dtype, precision=precision + rngs=rngs, + in_channels=time_freq_dim, + time_embed_dim=dim, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) self.act_fn = get_activation("silu") self.time_proj = nnx.Linear( - rngs=rngs, - in_features=dim, - out_features=time_proj_dim, - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + rngs=rngs, + in_features=dim, + out_features=time_proj_dim, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "embed", + "mlp", + ), + ), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), ) self.text_embedder = NNXPixArtAlphaTextProjection( - rngs=rngs, - in_features=text_embed_dim, - hidden_size=dim, - act_fn="gelu_tanh", + rngs=rngs, + in_features=text_embed_dim, + hidden_size=dim, + act_fn="gelu_tanh", ) - + def __call__( - self, - timestep: jax.Array, - encoder_hidden_states: jax.Array, - encoder_hidden_states_image: Optional[jax.Array] = None + self, timestep: jax.Array, encoder_hidden_states: jax.Array, encoder_hidden_states_image: Optional[jax.Array] = None ): timestep = self.timesteps_proj(timestep) temb = self.time_embedder(timestep) - + timestep_proj = self.time_proj(self.act_fn(temb)) encoder_hidden_states = self.text_embedder(encoder_hidden_states) @@ -161,83 +158,85 @@ class ApproximateGELU(nnx.Module): The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this [paper](https://arxiv.org/abs/1606.08415). """ + def __init__( - self, - rngs: nnx.Rngs, - dim_in: int, - dim_out: int, - bias: bool, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, + self, + rngs: nnx.Rngs, + dim_in: int, + dim_out: int, + bias: bool, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.proj = nnx.Linear( - rngs=rngs, - in_features=dim_in, - out_features=dim_out, - use_bias=bias, - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, + rngs=rngs, + in_features=dim_in, + out_features=dim_out, + use_bias=bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) - + def __call__(self, x: jax.Array) -> jax.Array: x = self.proj(x) return nnx.gelu(x) - + class WanFeedForward(nnx.Module): + def __init__( - self, - rngs: nnx.Rngs, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - inner_dim: int = None, - bias: bool = True, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, + self, + rngs: nnx.Rngs, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim: int = None, + bias: bool = True, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): if inner_dim is None: inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - + self.act_fn = None if activation_fn == "gelu-approximate": self.act_fn = ApproximateGELU( - rngs=rngs, - dim_in=dim, - dim_out=inner_dim, - bias=bias, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + rngs=rngs, dim_in=dim, dim_out=inner_dim, bias=bias, dtype=dtype, weights_dtype=weights_dtype, precision=precision ) else: raise NotImplementedError(f"{activation_fn} is not implemented.") self.proj_out = nnx.Linear( - rngs=rngs, - in_features=inner_dim, - out_features=dim_out, - use_bias=bias, - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed",)), + rngs=rngs, + in_features=inner_dim, + out_features=dim_out, + use_bias=bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "mlp", + "embed", + ), + ), ) - + def __call__(self, hidden_states: jax.Array) -> jax.Array: hidden_states = self.act_fn(hidden_states) return self.proj_out(hidden_states) - class WanTransformerBlock(nnx.Module): + def __init__( self, rngs: nnx.Rngs, @@ -247,7 +246,7 @@ def __init__( qk_norm: str = "rms_norm_across_heads", cross_attn_norm: bool = False, eps: float = 1e-6, - # In torch, this is none, so it can be ignored. + # In torch, this is none, so it can be ignored. # added_kv_proj_dim: Optional[int] = None, flash_min_seq_length: int = 4096, flash_block_sizes: BlockSizes = None, @@ -256,86 +255,72 @@ def __init__( weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None, attention: str = "dot_product", - ): - + # 1. Self-attention - self.norm1 = FP32LayerNorm( - rngs=rngs, - dim=dim, - eps=eps, - elementwise_affine=False - ) + self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) self.attn1 = FlaxWanAttention( - rngs=rngs, - query_dim=dim, - heads=num_heads, - dim_head= dim // num_heads, - qk_norm=qk_norm, - eps=eps, - flash_min_seq_length=flash_min_seq_length, - flash_block_sizes=flash_block_sizes, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision, - attention_kernel=attention + rngs=rngs, + query_dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + qk_norm=qk_norm, + eps=eps, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + attention_kernel=attention, ) # 1. Cross-attention self.attn2 = FlaxWanAttention( - rngs=rngs, - query_dim=dim, - heads=num_heads, - dim_head= dim // num_heads, - qk_norm=qk_norm, - eps=eps, - flash_min_seq_length=flash_min_seq_length, - flash_block_sizes=flash_block_sizes, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision, - attention_kernel=attention - ) - assert cross_attn_norm == True - self.norm2 = FP32LayerNorm( - rngs=rngs, - dim=dim, - eps=eps, - elementwise_affine=True + rngs=rngs, + query_dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + qk_norm=qk_norm, + eps=eps, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + attention_kernel=attention, ) + assert cross_attn_norm is True + self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) # 3. Feed-forward self.ffn = WanFeedForward( - rngs=rngs, - dim=dim, - inner_dim=ffn_dim, - activation_fn="gelu-approximate", - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision + rngs=rngs, + dim=dim, + inner_dim=ffn_dim, + activation_fn="gelu-approximate", + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) - + key = rngs.params() self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5) - - def __call__( - self, - hidden_states: jax.Array, - encoder_hidden_states: jax.Array, - temb: jax.Array, - rotary_emb: jax.Array - ): + + def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array): shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( - (self.scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 + (self.scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 ) - # 1. Self-attention - norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) - attn_output = self.attn1(hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb) + norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( + hidden_states.dtype + ) + attn_output = self.attn1( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb + ) hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) # 2. Cross-attention @@ -344,14 +329,18 @@ def __call__( hidden_states = hidden_states + attn_output # 3. Feed-forward - norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype) + norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype( + hidden_states.dtype + ) ff_output = self.ffn(norm_hidden_states) - hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(hidden_states.dtype) + hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype( + hidden_states.dtype + ) return hidden_states - + class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): - + @register_to_config def __init__( self, @@ -384,7 +373,7 @@ def __init__( inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels - #1. Patch & position embedding + # 1. Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) self.patch_embedding = nnx.Conv( in_channels, @@ -395,69 +384,78 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, None, "conv_out",)), + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + None, + None, + None, + None, + "conv_out", + ), + ), ) # 2. Condition embeddings # image_embedding_dim=1280 for I2V model self.condition_embedder = WanTimeTextImageEmbedding( - rngs=rngs, - dim=inner_dim, - time_freq_dim=freq_dim, - time_proj_dim=inner_dim * 6, - text_embed_dim=text_dim, - image_embed_dim=image_dim, - pos_embed_seq_len=pos_embed_seq_len + rngs=rngs, + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, ) # 3. Transformer blocks blocks = [] for _ in range(num_layers): block = WanTransformerBlock( - rngs=rngs, - dim=inner_dim, - ffn_dim=ffn_dim, - num_heads=num_attention_heads, - qk_norm=qk_norm, - cross_attn_norm=cross_attn_norm, - eps=eps, - flash_min_seq_length=flash_min_seq_length, - flash_block_sizes=flash_block_sizes, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision, - attention=attention + rngs=rngs, + dim=inner_dim, + ffn_dim=ffn_dim, + num_heads=num_attention_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + attention=attention, ) blocks.append(block) self.blocks = blocks self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) self.proj_out = nnx.Linear( - rngs=rngs, - in_features=inner_dim, - out_features=out_channels * math.prod(patch_size), - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)), + rngs=rngs, + in_features=inner_dim, + out_features=out_channels * math.prod(patch_size), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)), ) key = rngs.params() self.scale_shift_table = nnx.Param( - jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")) + jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")), ) def __call__( - self, - hidden_states: jax.Array, - timestep: jax.Array, - encoder_hidden_states: jax.Array, - is_uncond: jax.Array, # jnp.bool_ scalar - slg_mask: jax.Array, # jnp.bool_ array of shape (num_blocks,) - encoder_hidden_states_image: Optional[jax.Array] = None, - return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, + self, + hidden_states: jax.Array, + timestep: jax.Array, + encoder_hidden_states: jax.Array, + is_uncond: jax.Array, # jnp.bool_ scalar + slg_mask: jax.Array, # jnp.bool_ array of shape (num_blocks,) + encoder_hidden_states_image: Optional[jax.Array] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[jax.Array, Dict[str, jax.Array]]: batch_size, _, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size @@ -465,14 +463,13 @@ def __call__( post_patch_height = height // p_h post_patch_width = width // p_w - hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) rotary_emb = self.rope(hidden_states) hidden_states = self.patch_embedding(hidden_states) hidden_states = jax.lax.collapse(hidden_states, 1, -1) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( - timestep, encoder_hidden_states, encoder_hidden_states_image + timestep, encoder_hidden_states, encoder_hidden_states_image ) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) @@ -481,19 +478,21 @@ def __call__( for block_idx, block in enumerate(self.blocks): should_skip_block = slg_mask[block_idx] & is_uncond hidden_states = jax.lax.cond( - should_skip_block, - lambda hs: hs, # If true, pass through original hidden_states (skip block) - lambda _: block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb), - hidden_states + should_skip_block, + lambda hs: hs, # If true, pass through original hidden_states (skip block) + lambda _: block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb), + hidden_states, ) shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1) + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6)) hidden_states = jax.lax.collapse(hidden_states, 6, None) hidden_states = jax.lax.collapse(hidden_states, 4, 6) hidden_states = jax.lax.collapse(hidden_states, 2, 4) - return hidden_states \ No newline at end of file + return hidden_states diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 4f7effad6..f84346735 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -17,31 +17,33 @@ def _tuple_str_to_int(in_tuple): out_list.append(item) return tuple(out_list) + def rename_for_nnx(key): new_key = key if "norm_k" in key or "norm_q" in key: - new_key = key[:-1] + ("scale",) + new_key = key[:-1] + ("scale",) return new_key + def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] with jax.default_device(device): if hf_download: # download the index file for sharded models. - index_file_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename="diffusion_pytorch_model.safetensors.index.json") + index_file_path = hf_hub_download( + pretrained_model_name_or_path, subfolder="transformer", filename="diffusion_pytorch_model.safetensors.index.json" + ) # open the index file. - with open(index_file_path, 'r') as f: + with open(index_file_path, "r") as f: index_dict = json.load(f) model_files = set() for key in index_dict["weight_map"].keys(): model_files.add(index_dict["weight_map"][key]) - + model_files = list(model_files) tensors = {} for model_file in model_files: - ckpt_shard_path = hf_hub_download( - pretrained_model_name_or_path, subfolder="transformer", filename=model_file - ) + ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename=model_file) # now get all the filenames for the model that need downloading max_logging.log(f"Load and port Wan 2.1 transformer on {device}") @@ -52,7 +54,7 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, flax_state_dict = {} cpu = jax.local_devices(backend="cpu")[0] flattened_dict = flatten_dict(eval_shapes) - # turn all block numbers to strings just for matching weights. + # turn all block numbers to strings just for matching weights. # Later they will be turned back to ints. random_flax_state_dict = {} for key in flattened_dict: @@ -67,7 +69,7 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") pt_tuple_key = tuple(renamed_pt_key.split(".")) - + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) @@ -78,6 +80,7 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, jax.clear_caches() return flax_state_dict + def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] with jax.default_device(device): diff --git a/src/maxdiffusion/pipelines/wan/__init__.py b/src/maxdiffusion/pipelines/wan/__init__.py index dab0ec292..83a537f82 100644 --- a/src/maxdiffusion/pipelines/wan/__init__.py +++ b/src/maxdiffusion/pipelines/wan/__init__.py @@ -14,4 +14,4 @@ limitations under the License. """ -from .wan_pipeline import WanPipeline \ No newline at end of file +from .wan_pipeline import WanPipeline diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 3298f769b..80637f9a8 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -36,6 +36,7 @@ import re import torch + def basic_clean(text): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) @@ -66,10 +67,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): return wan_transformer # 1. Load config. - wan_config = WanModel.load_config( - config.pretrained_model_name_or_path, - subfolder="transformer" - ) + wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer") wan_config["mesh"] = mesh wan_config["dtype"] = config.activations_dtype wan_config["weights_dtype"] = config.weights_dtype @@ -97,15 +95,13 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) for path, val in flax.traverse_util.flatten_dict(params).items(): sharding = logical_state_sharding[path].value - try: - state[path].value = jax.device_put(val, sharding) - except: - raise ValueError("value should exist.") + state[path].value = jax.device_put(val, sharding) state = nnx.from_flat_state(state) wan_transformer = nnx.merge(graphdef, state, rest_of_state) return wan_transformer + @nnx.jit(static_argnums=(1,), donate_argnums=(0,)) def create_sharded_logical_model(model, logical_axis_rules): graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...) @@ -134,18 +130,19 @@ class WanPipeline: vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. """ + def __init__( - self, - tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, - transformer: WanModel, - vae: AutoencoderKLWan, - vae_cache: AutoencoderKLWanCache, - scheduler: FlaxUniPCMultistepScheduler, - scheduler_state: UniPCMultistepSchedulerState, - devices_array: np.array, - mesh: Mesh, - config: HyperParameters + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: WanModel, + vae: AutoencoderKLWan, + vae_cache: AutoencoderKLWanCache, + scheduler: FlaxUniPCMultistepScheduler, + scheduler_state: UniPCMultistepSchedulerState, + devices_array: np.array, + mesh: Mesh, + config: HyperParameters, ): self.tokenizer = tokenizer self.text_encoder = text_encoder @@ -164,34 +161,31 @@ def __init__( self.p_run_inference = None - @classmethod def load_text_encoder(cls, config: HyperParameters): text_encoder = UMT5EncoderModel.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="text_encoder", + config.pretrained_model_name_or_path, + subfolder="text_encoder", ) return text_encoder - @classmethod def load_tokenizer(cls, config: HyperParameters): tokenizer = AutoTokenizer.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="tokenizer", + config.pretrained_model_name_or_path, + subfolder="tokenizer", ) return tokenizer - @classmethod def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): wan_vae = AutoencoderKLWan.from_config( - config.pretrained_model_name_or_path, - subfolder="vae", - rngs=rngs, - mesh=mesh, - dtype=config.activations_dtype, - weights_dtype=config.weights_dtype + config.pretrained_model_name_or_path, + subfolder="vae", + rngs=rngs, + mesh=mesh, + dtype=config.activations_dtype, + weights_dtype=config.weights_dtype, ) vae_cache = AutoencoderKLWanCache(wan_vae) @@ -208,79 +202,75 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H wan_vae = p_create_sharded_logical_model(model=wan_vae) return wan_vae, vae_cache - @classmethod def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): with mesh: wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) return wan_transformer - @classmethod def load_scheduler(cls, config): scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="scheduler", - flow_shift=config.flow_shift # 5.0 for 720p, 3.0 for 480p + config.pretrained_model_name_or_path, + subfolder="scheduler", + flow_shift=config.flow_shift, # 5.0 for 720p, 3.0 for 480p ) return scheduler, scheduler_state - - @classmethod + @classmethod def from_pretrained(cls, config: HyperParameters, vae_only=False): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - transformer=None - tokenizer=None - scheduler=None - scheduler_state=None - text_encoder=None + transformer = None + tokenizer = None + scheduler = None + scheduler_state = None + text_encoder = None if not vae_only: with mesh: transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - + text_encoder = cls.load_text_encoder(config=config) tokenizer = cls.load_tokenizer(config=config) - + scheduler, scheduler_state = cls.load_scheduler(config=config) - + with mesh: wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) return WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - transformer=transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + vae=wan_vae, + vae_cache=vae_cache, + scheduler=scheduler, + scheduler_state=scheduler_state, + devices_array=devices_array, + mesh=mesh, + config=config, ) - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, ): prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt_clean(u) for u in prompt] batch_size = len(prompt) text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", ) text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() @@ -296,24 +286,23 @@ def _get_t5_prompt_embeds( prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) return prompt_embeds - def encode_prompt( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - if prompt_embeds is None: + if prompt_embeds is None: prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, ) prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype) @@ -321,62 +310,60 @@ def encode_prompt( negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_embeds = self._get_t5_prompt_embeds( - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, ) negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype) return prompt_embeds, negative_prompt_embeds - def prepare_latents( - self, - batch_size: int, - vae_scale_factor_temporal: int, - vae_scale_factor_spatial: int, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_channels_latents: int = 16, + self, + batch_size: int, + vae_scale_factor_temporal: int, + vae_scale_factor_spatial: int, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_channels_latents: int = 16, ): rng = jax.random.key(self.config.seed) num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 shape = ( - batch_size, - num_channels_latents, - num_latent_frames, - int(height) // vae_scale_factor_spatial, - int(width) // vae_scale_factor_spatial, + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // vae_scale_factor_spatial, + int(width) // vae_scale_factor_spatial, ) latents = jax.random.normal(rng, shape=shape, dtype=self.config.weights_dtype) return latents - def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: jax.Array = None, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, - slg_layers: List[int] = None, - slg_start: float = 0.0, - slg_end: float = 1.0 + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, + slg_layers: List[int] = None, + slg_start: float = 0.0, + slg_end: float = 1.0, ): if not vae_only: if num_frames % self.vae_scale_factor_temporal != 1: max_logging.log( - f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." ) num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) @@ -384,63 +371,63 @@ def __call__( # 2. Define call parameters if prompt is not None and isinstance(prompt, str): prompt = [prompt] - + batch_size = len(prompt) - + prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - max_sequence_length=max_sequence_length, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds + prompt=prompt, + negative_prompt=negative_prompt, + max_sequence_length=max_sequence_length, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) num_channel_latents = self.transformer.config.in_channels if latents is None: latents = self.prepare_latents( - batch_size=batch_size, - vae_scale_factor_temporal=self.vae_scale_factor_temporal, - vae_scale_factor_spatial=self.vae_scale_factor_spatial, - height=height, - width=width, - num_frames=num_frames, - num_channels_latents=num_channel_latents + batch_size=batch_size, + vae_scale_factor_temporal=self.vae_scale_factor_temporal, + vae_scale_factor_spatial=self.vae_scale_factor_spatial, + height=height, + width=width, + num_frames=num_frames, + num_channels_latents=num_channel_latents, ) data_sharding = PositionalSharding(self.devices_array).replicate() if len(prompt) % jax.device_count() == 0: data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) - + latents = jax.device_put(latents, data_sharding) prompt_embeds = jax.device_put(prompt_embeds, data_sharding) negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) scheduler_state = self.scheduler.set_timesteps( - self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape + self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape ) graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) p_run_inference = partial( - run_inference, - guidance_scale=guidance_scale, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - scheduler_state=scheduler_state, - slg_layers=slg_layers, - slg_start=slg_start, - slg_end=slg_end, - num_transformer_layers=self.transformer.config.num_layers + run_inference, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state, + slg_layers=slg_layers, + slg_start=slg_start, + slg_end=slg_end, + num_transformer_layers=self.transformer.config.num_layers, ) with self.mesh: latents = p_run_inference( - graphdef=graphdef, - sharded_state=state, - rest_of_state=rest_of_state, - latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds + graphdef=graphdef, + sharded_state=state, + rest_of_state=rest_of_state, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) @@ -456,49 +443,38 @@ def __call__( @jax.jit -def transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, - prompt_embeds, - is_uncond, - slg_mask): +def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds, is_uncond, slg_mask): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) return wan_transformer( - hidden_states=latents, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - is_uncond=is_uncond, - slg_mask=slg_mask + hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, is_uncond=is_uncond, slg_mask=slg_mask ) + def run_inference( - graphdef, - sharded_state, - rest_of_state, - latents: jnp.array, - prompt_embeds: jnp.array, - negative_prompt_embeds: jnp.array, - guidance_scale: float, - num_inference_steps: int, - scheduler : FlaxUniPCMultistepScheduler, - num_transformer_layers: int, - scheduler_state, - slg_layers: List[int] = None, - slg_start: float = 0.0, - slg_end: float = 1.0 - ): - do_classifier_free_guidance = guidance_scale > 1.0 - for step in range(num_inference_steps): - slg_mask = jnp.zeros(num_transformer_layers, dtype=jnp.bool_) - if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps): - slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True) - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - timestep = jnp.broadcast_to(t, latents.shape[0]) - - noise_pred = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents: jnp.array, + prompt_embeds: jnp.array, + negative_prompt_embeds: jnp.array, + guidance_scale: float, + num_inference_steps: int, + scheduler: FlaxUniPCMultistepScheduler, + num_transformer_layers: int, + scheduler_state, + slg_layers: List[int] = None, + slg_start: float = 0.0, + slg_end: float = 1.0, +): + do_classifier_free_guidance = guidance_scale > 1.0 + for step in range(num_inference_steps): + slg_mask = jnp.zeros(num_transformer_layers, dtype=jnp.bool_) + if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps): + slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True) + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents.shape[0]) + + noise_pred = transformer_forward_pass( graphdef, sharded_state, rest_of_state, @@ -506,11 +482,11 @@ def run_inference( timestep, prompt_embeds, is_uncond=jnp.array(False, dtype=jnp.bool_), - slg_mask=slg_mask - ) + slg_mask=slg_mask, + ) - if do_classifier_free_guidance: - noise_uncond = transformer_forward_pass( + if do_classifier_free_guidance: + noise_uncond = transformer_forward_pass( graphdef, sharded_state, rest_of_state, @@ -518,8 +494,8 @@ def run_inference( timestep, negative_prompt_embeds, is_uncond=jnp.array(True, dtype=jnp.bool_), - slg_mask=slg_mask - ) - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents \ No newline at end of file + slg_mask=slg_mask, + ) + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents diff --git a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py index 85e116137..b04a142de 100644 --- a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py @@ -33,875 +33,770 @@ @flax.struct.dataclass class UniPCMultistepSchedulerState: - """ - Data class to hold the mutable state of the FlaxUniPCMultistepScheduler. - """ - - common: CommonSchedulerState - - # Core schedule parameters (derived from CommonSchedulerState in create_state) - sigmas: jnp.ndarray - alpha_t: jnp.ndarray - sigma_t: jnp.ndarray - lambda_t: jnp.ndarray - init_noise_sigma: float - - # History buffers for multi-step solver - # `model_outputs` stores previous converted model outputs (e.g., predicted x0 or epsilon) - timesteps: jnp.ndarray = None - model_outputs: jnp.ndarray = None - timestep_list: jnp.ndarray = ( - None # Stores corresponding timesteps for `model_outputs` + """ + Data class to hold the mutable state of the FlaxUniPCMultistepScheduler. + """ + + common: CommonSchedulerState + + # Core schedule parameters (derived from CommonSchedulerState in create_state) + sigmas: jnp.ndarray + alpha_t: jnp.ndarray + sigma_t: jnp.ndarray + lambda_t: jnp.ndarray + init_noise_sigma: float + + # History buffers for multi-step solver + # `model_outputs` stores previous converted model outputs (e.g., predicted x0 or epsilon) + timesteps: jnp.ndarray = None + model_outputs: jnp.ndarray = None + timestep_list: jnp.ndarray = None # Stores corresponding timesteps for `model_outputs` + + # State variables for tracking progress and solver order + lower_order_nums: int = 0 + last_sample: Optional[jnp.ndarray] = None # Sample from the previous predictor step + step_index: Optional[int] = None + begin_index: Optional[int] = None # Used for img2img/inpaing + this_order: int = 0 # Current effective order of the UniPC solver for this step + + @classmethod + def create( + cls, + common_state: CommonSchedulerState, + alpha_t: jnp.ndarray, + sigma_t: jnp.ndarray, + lambda_t: jnp.ndarray, + sigmas: jnp.ndarray, + init_noise_sigma: jnp.ndarray, + ): + return cls( + common=common_state, + alpha_t=alpha_t, + sigma_t=sigma_t, + lambda_t=lambda_t, + sigmas=sigmas, + init_noise_sigma=init_noise_sigma, + lower_order_nums=0, + last_sample=None, + step_index=None, + begin_index=None, + this_order=0, ) - # State variables for tracking progress and solver order - lower_order_nums: int = 0 - last_sample: Optional[jnp.ndarray] = None # Sample from the previous predictor step - step_index: Optional[int] = None - begin_index: Optional[int] = None # Used for img2img/inpaing - this_order: int = 0 # Current effective order of the UniPC solver for this step - - @classmethod - def create( - cls, - common_state: CommonSchedulerState, - alpha_t: jnp.ndarray, - sigma_t: jnp.ndarray, - lambda_t: jnp.ndarray, - sigmas: jnp.ndarray, - init_noise_sigma: jnp.ndarray, - ): - return cls( - common=common_state, - alpha_t=alpha_t, - sigma_t=sigma_t, - lambda_t=lambda_t, - sigmas=sigmas, - init_noise_sigma=init_noise_sigma, - lower_order_nums=0, - last_sample=None, - step_index=None, - begin_index=None, - this_order=0, - ) - @flax.struct.dataclass(frozen=False) class FlaxUniPCMultistepSchedulerOutput(FlaxSchedulerOutput): - state: UniPCMultistepSchedulerState + state: UniPCMultistepSchedulerState class FlaxUniPCMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + `FlaxUniPCMultistepScheduler` is a JAX/Flax training-free framework designed for the fast sampling of diffusion models. + It implements the UniPC (Unified Predictor-Corrector) algorithm for efficient diffusion model sampling. + """ + + dtype: jnp.dtype + + @property + def has_state(self) -> bool: + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[jnp.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: Optional[FlaxSchedulerMixin] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + use_flow_sigmas: Optional[bool] = False, + flow_shift: Optional[float] = 1.0, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", + rescale_zero_terminal_snr: bool = False, + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + # Validation checks from original __init__ + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if ( + sum( + [ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ] + ) + > 1 + ): + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + if self.config.solver_type not in ["bh1", "bh2"]: + raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}") + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> UniPCMultistepSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + if self.config.get("rescale_zero_terminal_snr", False): + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + alphas_cumprod = common.alphas_cumprod + alphas_cumprod = alphas_cumprod.at[-1].set(2**-24) + common = common.replace(alphas_cumprod=alphas_cumprod) + + # Currently we only support VP-type noise schedule + alpha_t = jnp.sqrt(common.alphas_cumprod) + sigma_t = jnp.sqrt(1 - common.alphas_cumprod) + lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t) + sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + if self.config.solver_type not in ["bh1", "bh2"]: + if self.config.solver_type in ["midpoint", "heun", "logrho"]: + self.config.solver_type = "bh2" + else: + raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}") + + return UniPCMultistepSchedulerState.create( + common_state=common, + alpha_t=alpha_t, + sigma_t=sigma_t, + lambda_t=lambda_t, + sigmas=sigmas, + init_noise_sigma=init_noise_sigma, + ) + + def set_begin_index(self, state: UniPCMultistepSchedulerState, begin_index: int = 0) -> UniPCMultistepSchedulerState: """ - `FlaxUniPCMultistepScheduler` is a JAX/Flax training-free framework designed for the fast sampling of diffusion models. - It implements the UniPC (Unified Predictor-Corrector) algorithm for efficient diffusion model sampling. + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. """ - - dtype: jnp.dtype - - @property - def has_state(self) -> bool: - return True - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[jnp.ndarray, List[float]]] = None, - solver_order: int = 2, - prediction_type: str = "epsilon", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - predict_x0: bool = True, - solver_type: str = "bh2", - lower_order_final: bool = True, - disable_corrector: List[int] = [], - solver_p: Optional[FlaxSchedulerMixin] = None, - use_karras_sigmas: Optional[bool] = False, - use_exponential_sigmas: Optional[bool] = False, - use_beta_sigmas: Optional[bool] = False, - use_flow_sigmas: Optional[bool] = False, - flow_shift: Optional[float] = 1.0, - timestep_spacing: str = "linspace", - steps_offset: int = 0, - final_sigmas_type: Optional[str] = "zero", - rescale_zero_terminal_snr: bool = False, - dtype: jnp.dtype = jnp.float32, - ): - self.dtype = dtype - - # Validation checks from original __init__ - if self.config.use_beta_sigmas and not is_scipy_available(): - raise ImportError( - "Make sure to install scipy if you want to use beta sigmas." - ) - if ( - sum( - [ - self.config.use_beta_sigmas, - self.config.use_exponential_sigmas, - self.config.use_karras_sigmas, - ] - ) - > 1 - ): - raise ValueError( - "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." - ) - if self.config.solver_type not in ["bh1", "bh2"]: - raise NotImplementedError( - f"{self.config.solver_type} is not implemented for {self.__class__}" - ) - - def create_state( - self, common: Optional[CommonSchedulerState] = None - ) -> UniPCMultistepSchedulerState: - if common is None: - common = CommonSchedulerState.create(self) - - if self.config.get("rescale_zero_terminal_snr", False): - # Close to 0 without being 0 so first sigma is not inf - # FP16 smallest positive subnormal works well here - alphas_cumprod = common.alphas_cumprod - alphas_cumprod = alphas_cumprod.at[-1].set(2**-24) - common = common.replace(alphas_cumprod=alphas_cumprod) - - # Currently we only support VP-type noise schedule - alpha_t = jnp.sqrt(common.alphas_cumprod) - sigma_t = jnp.sqrt(1 - common.alphas_cumprod) - lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t) - sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 - - # standard deviation of the initial noise distribution - init_noise_sigma = jnp.array(1.0, dtype=self.dtype) - - if self.config.solver_type not in ["bh1", "bh2"]: - if self.config.solver_type in ["midpoint", "heun", "logrho"]: - self.config.solver_type = "bh2" - else: - raise NotImplementedError( - f"{self.config.solver_type} is not implemented for {self.__class__}" - ) - - return UniPCMultistepSchedulerState.create( - common_state=common, - alpha_t=alpha_t, - sigma_t=sigma_t, - lambda_t=lambda_t, - sigmas=sigmas, - init_noise_sigma=init_noise_sigma, + return state.replace(begin_index=begin_index) + + def set_timesteps( + self, + state: UniPCMultistepSchedulerState, + num_inference_steps: int, + shape: Tuple, + ) -> UniPCMultistepSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + """ + #### Copied from scheduling_dmpsolver_multistep_flax + last_timestep = self.config.num_train_timesteps + if self.config.timestep_spacing == "linspace": + timesteps = jnp.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].astype(jnp.int32) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (jnp.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(jnp.int32) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = jnp.arange(last_timestep, 0, -step_ratio).round().copy().astype(jnp.int32) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + # initial running values + sigmas = state.sigmas + + # TODO + # # Apply Karras/Exponential/Beta/Flow Sigmas if configured + if self.config.use_karras_sigmas: + # sigmas = _convert_to_karras_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + raise NotImplementedError("`use_karras_sigmas` is not implemented in JAX version yet.") + elif self.config.use_exponential_sigmas: + # sigmas = _convert_to_exponential_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + raise NotImplementedError("`use_exponential_sigmas` is not implemented in JAX version yet.") + elif self.config.use_beta_sigmas: + # sigmas = _convert_to_beta_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + raise NotImplementedError("`use_beta_sigmas` is not implemented in JAX version yet.") + if self.config.use_flow_sigmas: + alphas = jnp.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) + sigmas = 1.0 - alphas + sigmas = jnp.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() + timesteps = (sigmas * self.config.num_train_timesteps).copy().astype(jnp.int64) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" ) - - def set_begin_index( - self, state: UniPCMultistepSchedulerState, begin_index: int = 0 - ) -> UniPCMultistepSchedulerState: - """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. - """ - return state.replace(begin_index=begin_index) - - def set_timesteps( - self, - state: UniPCMultistepSchedulerState, - num_inference_steps: int, - shape: Tuple, - ) -> UniPCMultistepSchedulerState: - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - """ - #### Copied from scheduling_dmpsolver_multistep_flax - last_timestep = self.config.num_train_timesteps - if self.config.timestep_spacing == "linspace": - timesteps = ( - jnp.linspace(0, last_timestep - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .astype(jnp.int32) - ) - elif self.config.timestep_spacing == "leading": - step_ratio = last_timestep // (num_inference_steps + 1) - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - (jnp.arange(0, num_inference_steps + 1) * step_ratio) - .round()[::-1][:-1] - .copy() - .astype(jnp.int32) - ) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - jnp.arange(last_timestep, 0, -step_ratio) - .round() - .copy() - .astype(jnp.int32) - ) - timesteps -= 1 - else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." - ) - - # initial running values - sigmas = state.sigmas - - # TODO - # # Apply Karras/Exponential/Beta/Flow Sigmas if configured - if self.config.use_karras_sigmas: - # sigmas = _convert_to_karras_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) - raise NotImplementedError( - "`use_karras_sigmas` is not implemented in JAX version yet." - ) - elif self.config.use_exponential_sigmas: - # sigmas = _convert_to_exponential_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) - raise NotImplementedError( - "`use_exponential_sigmas` is not implemented in JAX version yet." - ) - elif self.config.use_beta_sigmas: - # sigmas = _convert_to_beta_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) - raise NotImplementedError( - "`use_beta_sigmas` is not implemented in JAX version yet." - ) - if self.config.use_flow_sigmas: - alphas = jnp.linspace( - 1, 1 / self.config.num_train_timesteps, num_inference_steps + 1 - ) - sigmas = 1.0 - alphas - sigmas = jnp.flip( - self.config.flow_shift - * sigmas - / (1 + (self.config.flow_shift - 1) * sigmas) - )[:-1].copy() - timesteps = ( - (sigmas * self.config.num_train_timesteps).copy().astype(jnp.int64) - ) - if self.config.final_sigmas_type == "sigma_min": - sigma_last = sigmas[-1] - elif self.config.final_sigmas_type == "zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" - ) - sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype( - jnp.float32 - ) - else: # Default case if none of the specialized sigmas are used - sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas) - if self.config.final_sigmas_type == "sigma_min": - sigma_last = ( - (1 - state.common.alphas_cumprod[0]) - / state.common.alphas_cumprod[0] - ) ** 0.5 - elif self.config.final_sigmas_type == "zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" - ) - sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype( - jnp.float32 - ) - - model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) - timestep_list = jnp.zeros( - (self.config.solver_order,), dtype=jnp.int32 # Timesteps are integers + sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype(jnp.float32) + else: # Default case if none of the specialized sigmas are used + sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - state.common.alphas_cumprod[0]) / state.common.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" ) - # Update the state with the new schedule and re-initialized history - return state.replace( - timesteps=timesteps, - sigmas=sigmas, - model_outputs=model_outputs, - timestep_list=timestep_list, - lower_order_nums=0, # Reset counters for a new inference run - step_index=None, - begin_index=None, - last_sample=None, - this_order=0, + sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype(jnp.float32) + + model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) + timestep_list = jnp.zeros((self.config.solver_order,), dtype=jnp.int32) # Timesteps are integers + # Update the state with the new schedule and re-initialized history + return state.replace( + timesteps=timesteps, + sigmas=sigmas, + model_outputs=model_outputs, + timestep_list=timestep_list, + lower_order_nums=0, # Reset counters for a new inference run + step_index=None, + begin_index=None, + last_sample=None, + this_order=0, + ) + + def convert_model_output( + self, + state: UniPCMultistepSchedulerState, + model_output: jnp.ndarray, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + Converts the model output based on the prediction type and current state. + """ + sigma = state.sigmas[state.step_index] # Current sigma + + # Ensure sigma is a JAX array for _sigma_to_alpha_sigma_t + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.config.predict_x0: + if self.config.prediction_type == "epsilon": + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + x0_pred = alpha_t * sample - sigma_t * model_output + elif self.config.prediction_type == "flow_prediction": + # Original code has `sigma_t = self.sigmas[self.step_index]`. + # This implies current sigma `sigma` is used as sigma_t for flow. + x0_pred = sample - sigma * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler." ) - def convert_model_output( - self, - state: UniPCMultistepSchedulerState, - model_output: jnp.ndarray, - sample: jnp.ndarray, - ) -> jnp.ndarray: - """ - Converts the model output based on the prediction type and current state. - """ - sigma = state.sigmas[state.step_index] # Current sigma - - # Ensure sigma is a JAX array for _sigma_to_alpha_sigma_t - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - - if self.config.predict_x0: - if self.config.prediction_type == "epsilon": - x0_pred = (sample - sigma_t * model_output) / alpha_t - elif self.config.prediction_type == "sample": - x0_pred = model_output - elif self.config.prediction_type == "v_prediction": - x0_pred = alpha_t * sample - sigma_t * model_output - elif self.config.prediction_type == "flow_prediction": - # Original code has `sigma_t = self.sigmas[self.step_index]`. - # This implies current sigma `sigma` is used as sigma_t for flow. - x0_pred = sample - sigma * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " - "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler." - ) - - if self.config.thresholding: - raise NotImplementedError("Dynamic thresholding isn't implemented.") - # x0_pred = self._threshold_sample(x0_pred) - return x0_pred - else: # self.config.predict_x0 is False - if self.config.prediction_type == "epsilon": - return model_output - elif self.config.prediction_type == "sample": - epsilon = (sample - alpha_t * model_output) / sigma_t - return epsilon - elif self.config.prediction_type == "v_prediction": - epsilon = alpha_t * model_output + sigma_t * sample - return epsilon - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the UniPCMultistepScheduler." - ) - - def multistep_uni_p_bh_update( - self, - state: UniPCMultistepSchedulerState, - model_output: jnp.ndarray, - sample: jnp.ndarray, - order: int, - ) -> jnp.ndarray: - """ - One step for the UniP (B(h) version) - the Predictor. - """ - if self.config.solver_p: - raise NotImplementedError( - "Nested `solver_p` is not implemented in JAX version yet." - ) - - m0 = state.model_outputs[ - self.config.solver_order - 1 - ] # Most recent stored converted model output - x = sample - - sigma_t_val, sigma_s0_val = ( - state.sigmas[state.step_index + 1], - state.sigmas[state.step_index], + if self.config.thresholding: + raise NotImplementedError("Dynamic thresholding isn't implemented.") + # x0_pred = self._threshold_sample(x0_pred) + return x0_pred + else: # self.config.predict_x0 is False + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." ) - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val) + def multistep_uni_p_bh_update( + self, + state: UniPCMultistepSchedulerState, + model_output: jnp.ndarray, + sample: jnp.ndarray, + order: int, + ) -> jnp.ndarray: + """ + One step for the UniP (B(h) version) - the Predictor. + """ + if self.config.solver_p: + raise NotImplementedError("Nested `solver_p` is not implemented in JAX version yet.") - lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10) - lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10) + m0 = state.model_outputs[self.config.solver_order - 1] # Most recent stored converted model output + x = sample - h = lambda_t - lambda_s0 + sigma_t_val, sigma_s0_val = ( + state.sigmas[state.step_index + 1], + state.sigmas[state.step_index], + ) - def rk_d1_loop_body(i, carry): - # Loop from i = 0 to order-2 - rks, D1s = carry - history_idx = self.config.solver_order - 2 - i - mi = state.model_outputs[history_idx] - si_val = state.timestep_list[history_idx] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val) + + lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10) + lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10) + + h = lambda_t - lambda_s0 + + def rk_d1_loop_body(i, carry): + # Loop from i = 0 to order-2 + rks, D1s = carry + history_idx = self.config.solver_order - 2 - i + mi = state.model_outputs[history_idx] + si_val = state.timestep_list[history_idx] + + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(state.sigmas[self.index_for_timestep(state, si_val)]) + lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10) + + rk = (lambda_si - lambda_s0) / h + Di = (mi - m0) / rk + + rks = rks.at[i].set(rk) + D1s = D1s.at[i].set(Di) + return rks, D1s + + rks_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + D1s_init = jnp.zeros((self.config.solver_order - 1, *m0.shape), dtype=m0.dtype) + if self.config.solver_order == 1: + # Dummy D1s array. It will not be used if order == 1 + D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype) + rks, D1s = jax.lax.fori_loop(0, order - 1, rk_d1_loop_body, (rks_init, D1s_init)) + rks = rks.at[order - 1].set(1.0) + + hh = -h if self.config.predict_x0 else h + h_phi_1 = jnp.expm1(hh) + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = jnp.expm1(hh) + else: + raise NotImplementedError() + + def rb_loop_body(i, carry): + R, b, current_h_phi_k, factorial_val = carry + R = R.at[i].set(jnp.power(rks, i)) + b = b.at[i].set(current_h_phi_k * factorial_val / B_h) + + def update_fn(vals): + _h_phi_k, _fac = vals + next_fac = _fac * (i + 2) + next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac + return next_h_phi_k, next_fac + + current_h_phi_k, factorial_val = jax.lax.cond( + i < order - 1, + update_fn, + lambda vals: vals, + (current_h_phi_k, factorial_val), + ) + return R, b, current_h_phi_k, factorial_val + + R_init = jnp.zeros((self.config.solver_order, self.config.solver_order), dtype=h.dtype) + b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + init_h_phi_k = h_phi_1 / hh - 1.0 + init_factorial = 1.0 + R, b, _, _ = jax.lax.fori_loop(0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial)) + + if len(D1s) > 0: + D1s = jnp.stack(D1s, axis=1) # Resulting shape (B, K, C, H, W) + + def solve_for_rhos_p(R_mat, b_vec, current_order): + # Create a mask for the top-left (current_order - 1) x (current_order - 1) sub-matrix + mask_size = self.config.solver_order - 1 + mask = jnp.arange(mask_size) < (current_order - 1) + mask_2d = mask[:, None] & mask[None, :] + + # Pad R with identity and b with zeros for a safe solve + R_safe = jnp.where( + mask_2d, + R_mat[:mask_size, :mask_size], + jnp.eye(mask_size, dtype=R_mat.dtype), + ) + b_safe = jnp.where(mask, b_vec[:mask_size], 0.0) + + # Solve the system and mask the result + solved_rhos = jnp.linalg.solve(R_safe, b_safe) + return jnp.where(mask, solved_rhos, 0.0) + + # Handle the special case for order == 2 + if self.config.solver_order == 1: + # Dummy rhos_p_padded for tracing. + rhos_p_order2 = jnp.zeros(1, dtype=x.dtype) + else: + rhos_p_order2 = jnp.zeros(self.config.solver_order - 1, dtype=x.dtype).at[0].set(0.5) + + # Get the result for the general case + rhos_p_general = solve_for_rhos_p(R, b, order) + + # Select the appropriate result based on the order + rhos_p = jnp.where(order == 2, rhos_p_order2, rhos_p_general) + + pred_res = jax.lax.cond( + order > 1, + lambda _: jnp.einsum("k,bkc...->bc...", rhos_p, D1s).astype(x.dtype), + # False branch: return a zero tensor with the correct shape. + lambda _: jnp.zeros_like(x), + operand=None, + ) - alpha_si, sigma_si = self._sigma_to_alpha_sigma_t( - state.sigmas[self.index_for_timestep(state, si_val)] - ) - lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10) + if self.config.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: # Predict epsilon + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + x_t = x_t_ - sigma_t * B_h * pred_res + + return x_t.astype(x.dtype) + + def multistep_uni_c_bh_update( + self, + state: UniPCMultistepSchedulerState, + this_model_output: jnp.ndarray, + last_sample: jnp.ndarray, # Sample after predictor `x_{t-1}` + this_sample: jnp.ndarray, # Sample before corrector `x_t` (after predictor step) + order: int, + ) -> jnp.ndarray: + """ + One step for the UniC (B(h) version) - the Corrector. + """ + model_output_list = state.model_outputs + m0 = model_output_list[self.config.solver_order - 1] # Most recent model output from history - rk = (lambda_si - lambda_s0) / h - Di = (mi - m0) / rk + if last_sample is not None: + x = last_sample + else: + # If it's None, create dummy data. This is for the tracing purpose + x = jnp.zeros_like(this_sample) - rks = rks.at[i].set(rk) - D1s = D1s.at[i].set(Di) - return rks, D1s + x_t = this_sample - rks_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) - D1s_init = jnp.zeros((self.config.solver_order - 1, *m0.shape), dtype=m0.dtype) - if self.config.solver_order == 1: - # Dummy D1s array. It will not be used if order == 1 - D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype) - rks, D1s = jax.lax.fori_loop( - 0, order - 1, rk_d1_loop_body, (rks_init, D1s_init) - ) - rks = rks.at[order - 1].set(1.0) - - hh = -h if self.config.predict_x0 else h - h_phi_1 = jnp.expm1(hh) - - if self.config.solver_type == "bh1": - B_h = hh - elif self.config.solver_type == "bh2": - B_h = jnp.expm1(hh) - else: - raise NotImplementedError() - - def rb_loop_body(i, carry): - R, b, current_h_phi_k, factorial_val = carry - R = R.at[i].set(jnp.power(rks, i)) - b = b.at[i].set(current_h_phi_k * factorial_val / B_h) - - def update_fn(vals): - _h_phi_k, _fac = vals - next_fac = _fac * (i + 2) - next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac - return next_h_phi_k, next_fac - - current_h_phi_k, factorial_val = jax.lax.cond( - i < order - 1, - update_fn, - lambda vals: vals, - (current_h_phi_k, factorial_val), - ) - return R, b, current_h_phi_k, factorial_val - - R_init = jnp.zeros( - (self.config.solver_order, self.config.solver_order), dtype=h.dtype - ) - b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) - init_h_phi_k = h_phi_1 / hh - 1.0 - init_factorial = 1.0 - R, b, _, _ = jax.lax.fori_loop( - 0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial) - ) + model_t = this_model_output - if len(D1s) > 0: - D1s = jnp.stack(D1s, axis=1) # Resulting shape (B, K, C, H, W) - - def solve_for_rhos_p(R_mat, b_vec, current_order): - # Create a mask for the top-left (current_order - 1) x (current_order - 1) sub-matrix - mask_size = self.config.solver_order - 1 - mask = jnp.arange(mask_size) < (current_order - 1) - mask_2d = mask[:, None] & mask[None, :] - - # Pad R with identity and b with zeros for a safe solve - R_safe = jnp.where( - mask_2d, - R_mat[:mask_size, :mask_size], - jnp.eye(mask_size, dtype=R_mat.dtype), - ) - b_safe = jnp.where(mask, b_vec[:mask_size], 0.0) - - # Solve the system and mask the result - solved_rhos = jnp.linalg.solve(R_safe, b_safe) - return jnp.where(mask, solved_rhos, 0.0) - - # Handle the special case for order == 2 - if self.config.solver_order == 1: - # Dummy rhos_p_padded for tracing. - rhos_p_order2 = jnp.zeros(1, dtype=x.dtype) - else: - rhos_p_order2 = ( - jnp.zeros(self.config.solver_order - 1, dtype=x.dtype).at[0].set(0.5) - ) - - # Get the result for the general case - rhos_p_general = solve_for_rhos_p(R, b, order) - - # Select the appropriate result based on the order - rhos_p = jnp.where(order == 2, rhos_p_order2, rhos_p_general) - - pred_res = jax.lax.cond( - order > 1, - lambda _: jnp.einsum("k,bkc...->bc...", rhos_p, D1s).astype(x.dtype), - # False branch: return a zero tensor with the correct shape. - lambda _: jnp.zeros_like(x), - operand=None, - ) + sigma_t_val = state.sigmas[state.step_index] + sigma_s0_val = state.sigmas[state.step_index - 1] - if self.config.predict_x0: - x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - x_t = x_t_ - alpha_t * B_h * pred_res - else: # Predict epsilon - x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - x_t = x_t_ - sigma_t * B_h * pred_res - - return x_t.astype(x.dtype) - - def multistep_uni_c_bh_update( - self, - state: UniPCMultistepSchedulerState, - this_model_output: jnp.ndarray, - last_sample: jnp.ndarray, # Sample after predictor `x_{t-1}` - this_sample: jnp.ndarray, # Sample before corrector `x_t` (after predictor step) - order: int, - ) -> jnp.ndarray: - """ - One step for the UniC (B(h) version) - the Corrector. - """ - model_output_list = state.model_outputs - m0 = model_output_list[ - self.config.solver_order - 1 - ] # Most recent model output from history - - if last_sample is not None: - x = last_sample - else: - # If it's None, create dummy data. This is for the tracing purpose - x = jnp.zeros_like(this_sample) - - x_t = this_sample - - model_t = this_model_output - - sigma_t_val = state.sigmas[state.step_index] - sigma_s0_val = state.sigmas[state.step_index - 1] - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val) - - lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10) - lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10) - - h = lambda_t - lambda_s0 - - def rk_d1_loop_body(i, carry): - # Loop from i = 0 to order-1. - rks, D1s = carry - - # Get history from state buffer - history_idx = self.config.solver_order - (i + 2) - mi = state.model_outputs[history_idx] - si_val = state.timestep_list[ - history_idx - ] # This is the actual timestep value - - alpha_si, sigma_si = self._sigma_to_alpha_sigma_t( - state.sigmas[self.index_for_timestep(state, si_val)] - ) - lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10) - - rk = (lambda_si - lambda_s0) / h - Di = (mi - m0) / rk - - # Update pre-allocated arrays - rks = rks.at[i].set(rk) - D1s = D1s.at[i].set(Di) - return rks, D1s - - # Pre-allocate arrays to max possible size - rks_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) - D1s_init = jnp.zeros((self.config.solver_order - 1, *m0.shape), dtype=m0.dtype) - if self.config.solver_order == 1: - # Dummy D1s array. It will not be used if order == 1. This is for tracing. - D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype) - - # Run the loop up to `order - 1` - rks, D1s = jax.lax.fori_loop( - 0, order - 1, rk_d1_loop_body, (rks_init, D1s_init) - ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val) - rks = rks.at[order - 1].set(1.0) - - hh = -h if self.config.predict_x0 else h - h_phi_1 = jnp.expm1(hh) - - if self.config.solver_type == "bh1": - B_h = hh - elif self.config.solver_type == "bh2": - B_h = jnp.expm1(hh) - else: - raise NotImplementedError() - - def rb_loop_body(i, carry): - # Loop from i = 0 to order-1 - R, b, current_h_phi_k, factorial_val = carry - - R = R.at[i].set(jnp.power(rks, i)) - b = b.at[i].set(current_h_phi_k * factorial_val / B_h) - - # Conditionally update phi_k and factorial for the next iteration - def update_fn(vals): - # This branch is taken if i < order - 1 - _h_phi_k, _fac = vals - next_fac = _fac * (i + 2) - next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac - return next_h_phi_k, next_fac - - current_h_phi_k, factorial_val = jax.lax.cond( - i < order - 1, - update_fn, # If true, update values - lambda vals: vals, # If false, pass through - (current_h_phi_k, factorial_val), - ) - return R, b, current_h_phi_k, factorial_val - - # Pre-allocate R and b to max size - R_init = jnp.zeros( - (self.config.solver_order, self.config.solver_order), dtype=h.dtype - ) - b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10) + lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10) - # Initialize loop carriers - init_h_phi_k = h_phi_1 / hh - 1.0 - init_factorial = 1.0 + h = lambda_t - lambda_s0 - R, b, _, _ = jax.lax.fori_loop( - 0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial) - ) + def rk_d1_loop_body(i, carry): + # Loop from i = 0 to order-1. + rks, D1s = carry - if len(D1s) > 0: - D1s = jnp.stack(D1s, axis=1) # (B, K, C, H, W) + # Get history from state buffer + history_idx = self.config.solver_order - (i + 2) + mi = state.model_outputs[history_idx] + si_val = state.timestep_list[history_idx] # This is the actual timestep value - def solve_for_rhos(R_mat, b_vec, current_order): - # Create a mask to select the first `current_order` elements - mask = jnp.arange(self.config.solver_order) < current_order - mask_2d = mask[:, None] & mask[None, :] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(state.sigmas[self.index_for_timestep(state, si_val)]) + lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10) - # Pad R with identity and b with zeros to create a safe, full-sized system - R_safe = jnp.where( - mask_2d, R_mat, jnp.eye(self.config.solver_order, dtype=R_mat.dtype) - ) - b_safe = jnp.where(mask, b_vec, 0.0) + rk = (lambda_si - lambda_s0) / h + Di = (mi - m0) / rk - # Solve the full-size system and mask the result - solved_rhos = jnp.linalg.solve(R_safe, b_safe) - return jnp.where(mask, solved_rhos, 0.0) + # Update pre-allocated arrays + rks = rks.at[i].set(rk) + D1s = D1s.at[i].set(Di) + return rks, D1s - rhos_c_order1 = ( - jnp.zeros(self.config.solver_order, dtype=x_t.dtype).at[0].set(0.5) - ) - rhos_c_general = solve_for_rhos(R, b, order) - rhos_c = jnp.where(order == 1, rhos_c_order1, rhos_c_general) + # Pre-allocate arrays to max possible size + rks_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + D1s_init = jnp.zeros((self.config.solver_order - 1, *m0.shape), dtype=m0.dtype) + if self.config.solver_order == 1: + # Dummy D1s array. It will not be used if order == 1. This is for tracing. + D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype) - D1_t = model_t - m0 + # Run the loop up to `order - 1` + rks, D1s = jax.lax.fori_loop(0, order - 1, rk_d1_loop_body, (rks_init, D1s_init)) - corr_res = jax.lax.cond( - order > 1, - lambda _: (jnp.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)), - lambda _: jnp.zeros_like(D1_t), - operand=None, - ) + rks = rks.at[order - 1].set(1.0) - final_rho = jnp.dot( - rhos_c, - jax.nn.one_hot(order - 1, self.config.solver_order, dtype=rhos_c.dtype), - ) + hh = -h if self.config.predict_x0 else h + h_phi_1 = jnp.expm1(hh) - if self.config.predict_x0: - x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - x_t = x_t_ - alpha_t * B_h * (corr_res + final_rho * D1_t) - else: - x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - x_t = x_t_ - sigma_t * B_h * (corr_res + final_rho * D1_t) - - return x_t.astype(x.dtype) - - def index_for_timestep( - self, - state: UniPCMultistepSchedulerState, - timestep: Union[int, jnp.ndarray], - schedule_timesteps: Optional[jnp.ndarray] = None, - ) -> int: - """ "Gets the step_index for timestep.""" - if schedule_timesteps is None: - schedule_timesteps = state.timesteps - - # QUINN!! - # timestep_val = ( - # timestep.item() - # if isinstance(timestep, jnp.ndarray) and timestep.ndim == 0 - # else timestep - # ) - timestep_val = timestep - - index_candidates = jnp.where( - schedule_timesteps == timestep_val, size=1, fill_value=-1 - )[0] - - step_index = jnp.where( - index_candidates[0] == -1, # No match found - len(schedule_timesteps) - 1, # Default to last index - index_candidates[0], - ) - return step_index - - def _init_step_index( - self, state: UniPCMultistepSchedulerState, timestep: Union[int, jnp.ndarray] - ) -> UniPCMultistepSchedulerState: - """Initializes the step_index counter for the scheduler.""" - if state.begin_index is None: - step_index_val = self.index_for_timestep(state, timestep) - return state.replace(step_index=step_index_val) - else: - return state.replace(step_index=state.begin_index) - - @partial(jax.jit, static_argnums=(0, 5)) # self is static_argnum=0 - def step( - self, - state: UniPCMultistepSchedulerState, - model_output: jnp.ndarray, # This is the direct output from the diffusion model (e.g., noise prediction) - timestep: Union[ - int, jnp.ndarray - ], # Current discrete timestep from the scheduler's sequence - sample: jnp.ndarray, # Current noisy sample (latent) - return_dict: bool = True, - generator: Optional[jax.random.PRNGKey] = None, # JAX random key - ) -> Union[FlaxUniPCMultistepSchedulerOutput, Tuple[jnp.ndarray]]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with - the multistep UniPC. - """ - if state.timesteps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - timestep_scalar = jnp.array(timestep) - - # Initialize step_index if it's the first step - if state.step_index is None: - state = self._init_step_index(state, timestep_scalar) - - # Determine if corrector should be used - use_corrector = ( - (state.step_index > 0) - & ( - ~jnp.isin( - state.step_index - 1, jnp.array(self.config.disable_corrector) - ) - ) - & (state.last_sample is not None) - ) + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = jnp.expm1(hh) + else: + raise NotImplementedError() - # Convert model_output (noise/v_pred) to x0_pred or epsilon_pred, based on prediction_type - model_output_for_history = self.convert_model_output( - state, model_output, sample - ) + def rb_loop_body(i, carry): + # Loop from i = 0 to order-1 + R, b, current_h_phi_k, factorial_val = carry - # Apply corrector if applicable - sample = jax.lax.cond( - use_corrector, - lambda: self.multistep_uni_c_bh_update( - state=state, - this_model_output=model_output_for_history, - last_sample=state.last_sample, - this_sample=sample, - order=state.this_order, - ), - lambda: sample, - ) + R = R.at[i].set(jnp.power(rks, i)) + b = b.at[i].set(current_h_phi_k * factorial_val / B_h) - # Update history buffers (model_outputs and timestep_list) - # Shift existing elements to the left and add new one at the end. - # `state.model_outputs` and `state.timestep_list` are fixed-size arrays. - # Example: - # t0:[None,...,model_output0] - # t1:[None,..model_output0,model_output1] - # ... - # tn:[model_output0,model_output1,...,model_output_n] - def step_idx0_branch(): - updated_model_outputs_history = state.model_outputs.at[-1].set( - model_output_for_history - ) - updated_timestep_list_history = state.timestep_list.at[-1].set( - timestep_scalar - ) - return updated_model_outputs_history, updated_timestep_list_history - - def non_step_idx0_branch(): - updated_model_outputs_history = jnp.roll( - state.model_outputs, shift=-1, axis=0 - ) - updated_model_outputs_history = updated_model_outputs_history.at[-1].set( - model_output_for_history - ) - - updated_timestep_list_history = jnp.roll(state.timestep_list, shift=-1) - updated_timestep_list_history = updated_timestep_list_history.at[-1].set( - timestep_scalar - ) - return updated_model_outputs_history, updated_timestep_list_history - - updated_model_outputs_history, updated_timestep_list_history = jax.lax.cond( - state.step_index == 0, step_idx0_branch, non_step_idx0_branch - ) - state = state.replace( - model_outputs=updated_model_outputs_history, - timestep_list=updated_timestep_list_history, - ) + # Conditionally update phi_k and factorial for the next iteration + def update_fn(vals): + # This branch is taken if i < order - 1 + _h_phi_k, _fac = vals + next_fac = _fac * (i + 2) + next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac + return next_h_phi_k, next_fac - # Determine the order for the current step (warmup phase logic) - this_order = jnp.where( - self.config.lower_order_final, - jnp.minimum( - self.config.solver_order, len(state.timesteps) - state.step_index - ), - self.config.solver_order, - ) + current_h_phi_k, factorial_val = jax.lax.cond( + i < order - 1, + update_fn, # If true, update values + lambda vals: vals, # If false, pass through + (current_h_phi_k, factorial_val), + ) + return R, b, current_h_phi_k, factorial_val + + # Pre-allocate R and b to max size + R_init = jnp.zeros((self.config.solver_order, self.config.solver_order), dtype=h.dtype) + b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + + # Initialize loop carriers + init_h_phi_k = h_phi_1 / hh - 1.0 + init_factorial = 1.0 + + R, b, _, _ = jax.lax.fori_loop(0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial)) + + if len(D1s) > 0: + D1s = jnp.stack(D1s, axis=1) # (B, K, C, H, W) + + def solve_for_rhos(R_mat, b_vec, current_order): + # Create a mask to select the first `current_order` elements + mask = jnp.arange(self.config.solver_order) < current_order + mask_2d = mask[:, None] & mask[None, :] + + # Pad R with identity and b with zeros to create a safe, full-sized system + R_safe = jnp.where(mask_2d, R_mat, jnp.eye(self.config.solver_order, dtype=R_mat.dtype)) + b_safe = jnp.where(mask, b_vec, 0.0) + + # Solve the full-size system and mask the result + solved_rhos = jnp.linalg.solve(R_safe, b_safe) + return jnp.where(mask, solved_rhos, 0.0) + + rhos_c_order1 = jnp.zeros(self.config.solver_order, dtype=x_t.dtype).at[0].set(0.5) + rhos_c_general = solve_for_rhos(R, b, order) + rhos_c = jnp.where(order == 1, rhos_c_order1, rhos_c_general) + + D1_t = model_t - m0 + + corr_res = jax.lax.cond( + order > 1, + lambda _: (jnp.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)), + lambda _: jnp.zeros_like(D1_t), + operand=None, + ) + + final_rho = jnp.dot( + rhos_c, + jax.nn.one_hot(order - 1, self.config.solver_order, dtype=rhos_c.dtype), + ) + + if self.config.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + final_rho * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + final_rho * D1_t) + + return x_t.astype(x.dtype) + + def index_for_timestep( + self, + state: UniPCMultistepSchedulerState, + timestep: Union[int, jnp.ndarray], + schedule_timesteps: Optional[jnp.ndarray] = None, + ) -> int: + """ "Gets the step_index for timestep.""" + if schedule_timesteps is None: + schedule_timesteps = state.timesteps + + # QUINN!! + # timestep_val = ( + # timestep.item() + # if isinstance(timestep, jnp.ndarray) and timestep.ndim == 0 + # else timestep + # ) + timestep_val = timestep + + index_candidates = jnp.where(schedule_timesteps == timestep_val, size=1, fill_value=-1)[0] + + step_index = jnp.where( + index_candidates[0] == -1, # No match found + len(schedule_timesteps) - 1, # Default to last index + index_candidates[0], + ) + return step_index + + def _init_step_index( + self, state: UniPCMultistepSchedulerState, timestep: Union[int, jnp.ndarray] + ) -> UniPCMultistepSchedulerState: + """Initializes the step_index counter for the scheduler.""" + if state.begin_index is None: + step_index_val = self.index_for_timestep(state, timestep) + return state.replace(step_index=step_index_val) + else: + return state.replace(step_index=state.begin_index) + + @partial(jax.jit, static_argnums=(0, 5)) # self is static_argnum=0 + def step( + self, + state: UniPCMultistepSchedulerState, + model_output: jnp.ndarray, # This is the direct output from the diffusion model (e.g., noise prediction) + timestep: Union[int, jnp.ndarray], # Current discrete timestep from the scheduler's sequence + sample: jnp.ndarray, # Current noisy sample (latent) + return_dict: bool = True, + generator: Optional[jax.random.PRNGKey] = None, # JAX random key + ) -> Union[FlaxUniPCMultistepSchedulerOutput, Tuple[jnp.ndarray]]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + """ + if state.timesteps is None: + raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler") + + timestep_scalar = jnp.array(timestep) + + # Initialize step_index if it's the first step + if state.step_index is None: + state = self._init_step_index(state, timestep_scalar) - # Warmup for multistep: `this_order` can't exceed `lower_order_nums + 1` - new_this_order = jnp.minimum(this_order, state.lower_order_nums + 1) - state = state.replace(this_order=new_this_order) + # Determine if corrector should be used + use_corrector = ( + (state.step_index > 0) + & (~jnp.isin(state.step_index - 1, jnp.array(self.config.disable_corrector))) + & (state.last_sample is not None) + ) - # Store current sample as `last_sample` for the *next* step's corrector - state = state.replace(last_sample=sample) + # Convert model_output (noise/v_pred) to x0_pred or epsilon_pred, based on prediction_type + model_output_for_history = self.convert_model_output(state, model_output, sample) - # UniP predictor step - prev_sample = self.multistep_uni_p_bh_update( + # Apply corrector if applicable + sample = jax.lax.cond( + use_corrector, + lambda: self.multistep_uni_c_bh_update( state=state, - model_output=model_output, - sample=sample, + this_model_output=model_output_for_history, + last_sample=state.last_sample, + this_sample=sample, order=state.this_order, - ) + ), + lambda: sample, + ) - # Update lower_order_nums for warmup - new_lower_order_nums = jnp.where( - state.lower_order_nums < self.config.solver_order, - state.lower_order_nums + 1, - state.lower_order_nums, - ) - state = state.replace(lower_order_nums=new_lower_order_nums) - # Upon completion, increase step index by one - state = state.replace(step_index=state.step_index + 1) - - # Return the updated sample and state - if not return_dict: - return (prev_sample, state) - - return FlaxUniPCMultistepSchedulerOutput(prev_sample=prev_sample, state=state) - - def scale_model_input( - self, state: UniPCMultistepSchedulerState, sample: jnp.ndarray, *args, **kwargs - ) -> jnp.ndarray: - """ - UniPC does not scale model input, so it returns the sample unchanged. - """ - return sample - - def add_noise( - self, - state: UniPCMultistepSchedulerState, - original_samples: jnp.ndarray, - noise: jnp.ndarray, - timesteps: jnp.ndarray, - ) -> jnp.ndarray: - return add_noise_common(state.common, original_samples, noise, timesteps) - - def _sigma_to_alpha_sigma_t(self, sigma): - if self.config.use_flow_sigmas: - alpha_t = 1 - sigma - sigma_t = sigma - else: - alpha_t = 1 / ((sigma**2 + 1) ** 0.5) - sigma_t = sigma * alpha_t - - return alpha_t, sigma_t - - def __len__(self) -> int: - return self.config.num_train_timesteps + # Update history buffers (model_outputs and timestep_list) + # Shift existing elements to the left and add new one at the end. + # `state.model_outputs` and `state.timestep_list` are fixed-size arrays. + # Example: + # t0:[None,...,model_output0] + # t1:[None,..model_output0,model_output1] + # ... + # tn:[model_output0,model_output1,...,model_output_n] + def step_idx0_branch(): + updated_model_outputs_history = state.model_outputs.at[-1].set(model_output_for_history) + updated_timestep_list_history = state.timestep_list.at[-1].set(timestep_scalar) + return updated_model_outputs_history, updated_timestep_list_history + + def non_step_idx0_branch(): + updated_model_outputs_history = jnp.roll(state.model_outputs, shift=-1, axis=0) + updated_model_outputs_history = updated_model_outputs_history.at[-1].set(model_output_for_history) + + updated_timestep_list_history = jnp.roll(state.timestep_list, shift=-1) + updated_timestep_list_history = updated_timestep_list_history.at[-1].set(timestep_scalar) + return updated_model_outputs_history, updated_timestep_list_history + + updated_model_outputs_history, updated_timestep_list_history = jax.lax.cond( + state.step_index == 0, step_idx0_branch, non_step_idx0_branch + ) + state = state.replace( + model_outputs=updated_model_outputs_history, + timestep_list=updated_timestep_list_history, + ) + + # Determine the order for the current step (warmup phase logic) + this_order = jnp.where( + self.config.lower_order_final, + jnp.minimum(self.config.solver_order, len(state.timesteps) - state.step_index), + self.config.solver_order, + ) + + # Warmup for multistep: `this_order` can't exceed `lower_order_nums + 1` + new_this_order = jnp.minimum(this_order, state.lower_order_nums + 1) + state = state.replace(this_order=new_this_order) + + # Store current sample as `last_sample` for the *next* step's corrector + state = state.replace(last_sample=sample) + + # UniP predictor step + prev_sample = self.multistep_uni_p_bh_update( + state=state, + model_output=model_output, + sample=sample, + order=state.this_order, + ) + + # Update lower_order_nums for warmup + new_lower_order_nums = jnp.where( + state.lower_order_nums < self.config.solver_order, + state.lower_order_nums + 1, + state.lower_order_nums, + ) + state = state.replace(lower_order_nums=new_lower_order_nums) + # Upon completion, increase step index by one + state = state.replace(step_index=state.step_index + 1) + + # Return the updated sample and state + if not return_dict: + return (prev_sample, state) + + return FlaxUniPCMultistepSchedulerOutput(prev_sample=prev_sample, state=state) + + def scale_model_input(self, state: UniPCMultistepSchedulerState, sample: jnp.ndarray, *args, **kwargs) -> jnp.ndarray: + """ + UniPC does not scale model input, so it returns the sample unchanged. + """ + return sample + + def add_noise( + self, + state: UniPCMultistepSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return add_noise_common(state.common, original_samples, noise, timesteps) + + def _sigma_to_alpha_sigma_t(self, sigma): + if self.config.use_flow_sigmas: + alpha_t = 1 - sigma + sigma_t = sigma + else: + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + def __len__(self) -> int: + return self.config.num_train_timesteps diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 4b48fe349..5eba28a29 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -23,12 +23,12 @@ from jax.sharding import Mesh from .. import pyconfig -from ..max_utils import ( - create_device_mesh, - get_flash_block_sizes -) +from ..max_utils import (create_device_mesh, get_flash_block_sizes) from ..models.wan.transformers.transformer_wan import ( - WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock, WanModel + WanRotaryPosEmbed, + WanTimeTextImageEmbedding, + WanTransformerBlock, + WanModel, ) from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection from ..models.normalization_flax import FP32LayerNorm @@ -36,10 +36,12 @@ THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + class WanTransformerTest(unittest.TestCase): + def setUp(self): WanTransformerTest.dummy_data = {} - + def test_rotary_pos_embed(self): batch_size = 1 channels = 16 @@ -48,11 +50,7 @@ def test_rotary_pos_embed(self): width = 160 hidden_states_shape = (batch_size, frames, height, width, channels) dummy_hidden_states = jnp.ones(hidden_states_shape) - wan_rot_embed = WanRotaryPosEmbed( - attention_head_dim=128, - patch_size=[1, 2, 2], - max_seq_len=1024 - ) + wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024) dummy_output = wan_rot_embed(dummy_hidden_states) assert dummy_output.shape == (1, 1, 75600, 64) @@ -60,11 +58,7 @@ def test_nnx_pixart_alpha_text_projection(self): key = jax.random.key(0) rngs = nnx.Rngs(key) dummy_caption = jnp.ones((1, 512, 4096)) - layer = NNXPixArtAlphaTextProjection( - rngs=rngs, - in_features=4096, - hidden_size=5120 - ) + layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120) dummy_output = layer(dummy_caption) dummy_output.shape == (1, 512, 5120) @@ -73,11 +67,7 @@ def test_nnx_timestep_embedding(self): rngs = nnx.Rngs(key) dummy_sample = jnp.ones((1, 256)) - layer = NNXTimestepEmbedding( - rngs=rngs, - in_channels=256, - time_embed_dim=5120 - ) + layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120) dummy_output = layer(dummy_sample) assert dummy_output.shape == (1, 5120) @@ -87,49 +77,42 @@ def test_fp32_layer_norm(self): batch_size = 1 dummy_hidden_states = jnp.ones((batch_size, 75600, 5120)) # expected same output shape with same dtype - layer = FP32LayerNorm( - rngs=rngs, - dim=5120, - eps=1e-6, - elementwise_affine=False - ) + layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False) dummy_output = layer(dummy_hidden_states) - assert dummy_output.shape == dummy_hidden_states.shape + assert dummy_output.shape == dummy_hidden_states.shape def test_wan_time_text_embedding(self): key = jax.random.key(0) rngs = nnx.Rngs(key) batch_size = 1 - dim=5120 - time_freq_dim=256 - time_proj_dim=30720 - text_embed_dim=4096 + dim = 5120 + time_freq_dim = 256 + time_proj_dim = 30720 + text_embed_dim = 4096 layer = WanTimeTextImageEmbedding( - rngs=rngs, - dim=dim, - time_freq_dim=time_freq_dim, - time_proj_dim=time_proj_dim, - text_embed_dim=text_embed_dim + rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim ) - + dummy_timestep = jnp.ones(batch_size) encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim) dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape) - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(dummy_timestep, dummy_encoder_hidden_states) + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer( + dummy_timestep, dummy_encoder_hidden_states + ) assert temb.shape == (batch_size, dim) assert timestep_proj.shape == (batch_size, time_proj_dim) assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim) - + def test_wan_block(self): key = jax.random.key(0) rngs = nnx.Rngs(key) pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, ) config = pyconfig.config @@ -139,12 +122,12 @@ def test_wan_block(self): mesh = Mesh(devices_array, config.mesh_axes) - dim=5120 - ffn_dim=13824 - num_heads=40 - qk_norm="rms_norm_across_heads" - cross_attn_norm=True - eps=1e-6 + dim = 5120 + ffn_dim = 13824 + num_heads = 40 + qk_norm = "rms_norm_across_heads" + cross_attn_norm = True + eps = 1e-6 batch_size = 1 channels = 16 @@ -157,46 +140,40 @@ def test_wan_block(self): hidden_states_shape = (batch_size, frames, height, width, channels) dummy_hidden_states = jnp.ones(hidden_states_shape) - wan_rot_embed = WanRotaryPosEmbed( - attention_head_dim=128, - patch_size=[1, 2, 2], - max_seq_len=1024 - ) + wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024) dummy_rotary_emb = wan_rot_embed(dummy_hidden_states) assert dummy_rotary_emb.shape == (batch_size, 1, hidden_dim, 64) - + # for transformer block dummy_hidden_states = jnp.ones((batch_size, hidden_dim, dim)) - + dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim)) dummy_temb = jnp.ones((batch_size, 6, dim)) wan_block = WanTransformerBlock( - rngs=rngs, - dim=dim, - ffn_dim=ffn_dim, - num_heads=num_heads, - qk_norm=qk_norm, - cross_attn_norm=cross_attn_norm, - eps=eps, - attention="flash", - mesh=mesh, - flash_block_sizes=flash_block_sizes + rngs=rngs, + dim=dim, + ffn_dim=ffn_dim, + num_heads=num_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + attention="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes, ) dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb) assert dummy_output.shape == dummy_hidden_states.shape - - def test_wan_attention(self): pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, ) config = pyconfig.config @@ -207,11 +184,7 @@ def test_wan_attention(self): width = 160 hidden_states_shape = (batch_size, frames, height, width, channels) dummy_hidden_states = jnp.ones(hidden_states_shape) - wan_rot_embed = WanRotaryPosEmbed( - attention_head_dim=128, - patch_size=[1, 2, 2], - max_seq_len=1024 - ) + wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024) dummy_rotary_emb = wan_rot_embed(dummy_hidden_states) key = jax.random.key(0) @@ -224,47 +197,47 @@ def test_wan_attention(self): batch_size = 1 query_dim = 5120 attention = FlaxWanAttention( - rngs=rngs, - query_dim=query_dim, - heads=40, - dim_head=128, - attention_kernel="flash", - mesh=mesh, - flash_block_sizes=flash_block_sizes, + rngs=rngs, + query_dim=query_dim, + heads=40, + dim_head=128, + attention_kernel="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes, ) - + dummy_hidden_states_shape = (batch_size, 75600, query_dim) dummy_hidden_states = jnp.ones(dummy_hidden_states_shape) dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape) dummy_output = attention( - hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb + hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb ) assert dummy_output.shape == dummy_hidden_states_shape # dot product try: attention = FlaxWanAttention( - rngs=rngs, - query_dim=query_dim, - heads=40, - dim_head=128, - attention_kernel="dot_product", - split_head_dim=True, - mesh=mesh, - flash_block_sizes=flash_block_sizes, + rngs=rngs, + query_dim=query_dim, + heads=40, + dim_head=128, + attention_kernel="dot_product", + split_head_dim=True, + mesh=mesh, + flash_block_sizes=flash_block_sizes, ) - except NotImplementedError as e: + except NotImplementedError: pass - + def test_wan_model(self): pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, ) config = pyconfig.config @@ -284,23 +257,21 @@ def test_wan_model(self): mesh = Mesh(devices_array, config.mesh_axes) batch_size = 1 - query_dim = 5120 wan_model = WanModel( - rngs=rngs, - attention="flash", - mesh=mesh, - flash_block_sizes=flash_block_sizes, + rngs=rngs, + attention="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes, ) dummy_timestep = jnp.ones((batch_size)) dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) dummy_output = wan_model( - hidden_states=dummy_hidden_states, - timestep=dummy_timestep, - encoder_hidden_states=dummy_encoder_hidden_states + hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states ) assert dummy_output.shape == hidden_states_shape + if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index fe037c255..7b131e7fb 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -25,7 +25,7 @@ from jax.sharding import Mesh from .. import pyconfig from ..max_utils import ( - create_device_mesh, + create_device_mesh, ) import numpy as np import unittest @@ -258,11 +258,11 @@ def test_3d_conv(self): key = jax.random.key(0) rngs = nnx.Rngs(key) pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, ) config = pyconfig.config devices_array = create_device_mesh(config) @@ -289,7 +289,7 @@ def test_3d_conv(self): kernel_size=(kernel_d, kernel_h, kernel_w), padding=(padding_d, padding_h, padding_w), rngs=rngs, # Pass rngs for initialization, - mesh=mesh + mesh=mesh, ) # --- Test Case 1: No Cache --- @@ -310,11 +310,11 @@ def test_wan_residual(self): key = jax.random.key(0) rngs = nnx.Rngs(key) pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, ) config = pyconfig.config devices_array = create_device_mesh(config) @@ -329,12 +329,7 @@ def test_wan_residual(self): input_shape = (batch, t, height, width, dim) expected_output_shape = (batch, t, height, width, dim) - wan_residual_block = WanResidualBlock( - in_dim=in_dim, - out_dim=out_dim, - rngs=rngs, - mesh=mesh - ) + wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) dummy_input = jnp.ones(input_shape) dummy_output = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape @@ -344,12 +339,7 @@ def test_wan_residual(self): out_dim = 196 expected_output_shape = (batch, t, height, width, out_dim) - wan_residual_block = WanResidualBlock( - in_dim=in_dim, - out_dim=out_dim, - rngs=rngs, - mesh=mesh - ) + wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) dummy_input = jnp.ones(input_shape) dummy_output = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape @@ -372,11 +362,11 @@ def test_wan_midblock(self): key = jax.random.key(0) rngs = nnx.Rngs(key) pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, ) config = pyconfig.config devices_array = create_device_mesh(config) @@ -396,11 +386,11 @@ def test_wan_decode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, ) config = pyconfig.config devices_array = create_device_mesh(config) @@ -419,7 +409,7 @@ def test_wan_decode(self): num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, - mesh=mesh + mesh=mesh, ) vae_cache = AutoencoderKLWanCache(wan_vae) batch = 1 @@ -440,11 +430,11 @@ def test_wan_encode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, ) config = pyconfig.config devices_array = create_device_mesh(config) @@ -463,7 +453,7 @@ def test_wan_encode(self): num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, - mesh=mesh + mesh=mesh, ) vae_cache = AutoencoderKLWanCache(wan_vae) batch = 1 @@ -485,22 +475,17 @@ def vae_encode(video, wan_vae, vae_cache, key): key = jax.random.key(0) rngs = nnx.Rngs(key) pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, ) config = pyconfig.config devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - wan_vae = AutoencoderKLWan.from_config( - config.pretrained_model_name_or_path, - subfolder="vae", - rngs=rngs, - mesh=mesh - ) + wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh) vae_cache = AutoencoderKLWanCache(wan_vae) video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" video = load_video(video_path) diff --git a/src/maxdiffusion/train_wan.py b/src/maxdiffusion/train_wan.py index 62d9ba859..8d4774987 100644 --- a/src/maxdiffusion/train_wan.py +++ b/src/maxdiffusion/train_wan.py @@ -21,11 +21,14 @@ from maxdiffusion import max_logging, pyconfig from maxdiffusion.train_utils import validate_train_config + def train(config): from maxdiffusion.trainers.wan_trainer import WanTrainer + trainer = WanTrainer(config) trainer.start_training() + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) config = pyconfig.config @@ -33,5 +36,6 @@ def main(argv: Sequence[str]) -> None: max_logging.log(f"Found {jax.device_count()} devices.") train(config) + if __name__ == "__main__": - app.run(main) \ No newline at end of file + app.run(main) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index c60d59b97..d9e1d5d4c 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -24,13 +24,12 @@ from flax import nnx from ..schedulers import FlaxEulerDiscreteScheduler from .. import max_utils, max_logging, train_utils, maxdiffusion_utils -from ..checkpointing.wan_checkpointer import ( - WanCheckpointer, - WAN_CHECKPOINT -) +from ..checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT) from multihost_dataloading import _form_global_array + class WanTrainer(WanCheckpointer): + def __init__(self, config): WanCheckpointer.__init__(self, config, WAN_CHECKPOINT) if config.train_text_encoder: @@ -52,7 +51,7 @@ def create_scheduler(self, pipeline, params): return noise_scheduler, noise_scheduler_state def calculate_tflops(self, pipeline): - max_logging.log(f"WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0...") + max_logging.log("WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0...") return 0 def load_dataset(self, pipeline): @@ -73,11 +72,13 @@ def start_training(self): dummy_inputs = self.load_dataset(pipeline) mesh = pipeline.mesh optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5) - dummy_inputs = tuple([jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs]) + dummy_inputs = tuple( + [jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs] + ) self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs) - + def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): - + graphdef, state = nnx.split((pipeline.transformer, optimizer)) writer = max_utils.initialize_summary_writer(self.config) num_model_parameters = max_utils.calculate_num_params_from_pytree(state[0]) @@ -90,12 +91,11 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): max_logging.log(f" Instantaneous batch size per device = {self.config.per_device_batch_size}") max_logging.log(f" Total train batch size (w. parallel & distributed) = {self.global_batch_size}") max_logging.log(f" Total optimization steps = {self.config.max_train_steps}") - - + state = state.to_pure_dict() p_train_step = jax.jit( - train_step, - donate_argnums=(0,), + train_step, + donate_argnums=(0,), ) rng = jax.random.key(self.config.seed) start_step = 0 @@ -117,7 +117,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): max_utils.activate_profiler(self.config) with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh: state, train_metric, rng = p_train_step(state, graphdef, data, rng) - + new_time = datetime.datetime.now() if self.config.enable_profiler and step == last_profiling_step: @@ -130,38 +130,38 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) last_step_completion = new_time + def train_step(state, graphdef, data, rng): return step_optimizer(graphdef, state, data, rng) + def step_optimizer(graphdef, state, data, rng): _, new_rng = jax.random.split(rng) + def loss_fn(model): latents, prompt_embeds, timesteps = data - noise = jax.random.normal( - key=new_rng, - shape=latents.shape, - dtype=latents.dtype - ) + noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) # TODO - add noise here model_pred = model( - hidden_states=noise, - timestep=timesteps, - encoder_hidden_states=prompt_embeds, - is_uncond=jnp.array(False, dtype=jnp.bool_), - slg_mask=jnp.zeros(1, dtype=jnp.bool_) + hidden_states=noise, + timestep=timesteps, + encoder_hidden_states=prompt_embeds, + is_uncond=jnp.array(False, dtype=jnp.bool_), + slg_mask=jnp.zeros(1, dtype=jnp.bool_), ) target = noise - latents loss = (target - model_pred) ** 2 loss = jnp.mean(loss) - #breakpoint() + # breakpoint() return loss + model, optimizer = nnx.merge(graphdef, state) loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(grads) state = nnx.state((model, optimizer)) state = state.to_pure_dict() metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} - return state, metrics, new_rng \ No newline at end of file + return state, metrics, new_rng From 54946446599b0d613ff44bd5e9c7dae272962847 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 13 Jun 2025 20:22:42 +0000 Subject: [PATCH 51/54] halves inference time. --- src/maxdiffusion/configs/base_wan_14b.yml | 4 ++- src/maxdiffusion/generate_wan.py | 34 +++++++++++++++---- .../wan/transformers/transformer_wan.py | 14 ++++---- .../pipelines/wan/wan_pipeline.py | 3 +- 4 files changed, 40 insertions(+), 15 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index dda08817f..1dd81b075 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -112,8 +112,8 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], + ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], - ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], @@ -182,6 +182,8 @@ num_train_epochs: 1 seed: 0 output_dir: 'sdxl-model-finetuned' per_device_batch_size: 1 +# If global_batch_size % jax.device_count is not 0, use FSDP sharding. +global_batch_size: 0 warmup_steps_fraction: 0.1 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 5791d8a8f..057a6b25a 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -16,7 +16,7 @@ import jax import time from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline -from maxdiffusion import pyconfig +from maxdiffusion import pyconfig, max_logging from absl import app from maxdiffusion.utils import export_to_video @@ -30,9 +30,17 @@ def run(config): slg_layers = config.slg_layers slg_start = config.slg_start slg_end = config.slg_end + # If global_batch_size % jax.device_count is not 0, use FSDP sharding. + global_batch_size = config.global_batch_size + if global_batch_size != 0: + batch_multiplier = global_batch_size + else: + batch_multiplier = jax.device_count() * config.per_device_batch_size - prompt = [config.prompt] * jax.device_count() - negative_prompt = [config.negative_prompt] * jax.device_count() + prompt = [config.prompt] * batch_multiplier + negative_prompt = [config.negative_prompt] * batch_multiplier + + max_logging.log(f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}") videos = pipeline( prompt=prompt, @@ -51,6 +59,23 @@ def run(config): for i in range(len(videos)): export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps) s0 = time.perf_counter() + videos = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + slg_layers=slg_layers, + slg_start=slg_start, + slg_end=slg_end, + ) + print("generation time: ", (time.perf_counter() - s0)) + for i in range(len(videos)): + export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps) + + s0 = time.perf_counter() with jax.profiler.trace("/tmp/trace/"): videos = pipeline( prompt=prompt, @@ -65,9 +90,6 @@ def run(config): slg_end=slg_end, ) print("generation time: ", (time.perf_counter() - s0)) - for i in range(len(videos)): - export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps) - def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index c79a21bf7..36da470bf 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -311,29 +311,29 @@ def __init__( def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array): shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( - (self.scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 + (self.scale_shift_table + temb), 6, axis=1 ) # 1. Self-attention - norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( + norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype( hidden_states.dtype ) attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb ) - hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) + hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype) # 2. Cross-attention - norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)) + norm_hidden_states = self.norm2(hidden_states) attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) hidden_states = hidden_states + attn_output # 3. Feed-forward - norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype( + norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype( hidden_states.dtype ) ff_output = self.ffn(norm_hidden_states) - hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype( + hidden_states = (hidden_states + ff_output * c_gate_msa).astype( hidden_states.dtype ) return hidden_states @@ -485,7 +485,7 @@ def __call__( ) shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) - hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) + hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).astype(hidden_states.dtype) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 80637f9a8..8d9a2986b 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -21,6 +21,7 @@ import flax import flax.linen as nn from flax import nnx +from flax.linen import partitioning as nn_partitioning from ...pyconfig import HyperParameters from ... import max_logging from ... import max_utils @@ -420,7 +421,7 @@ def __call__( num_transformer_layers=self.transformer.config.num_layers, ) - with self.mesh: + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( graphdef=graphdef, sharded_state=state, From fc77dc05273d83eba7c9c9c3f2f9b20cece0cb3a Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 16 Jun 2025 19:43:30 +0000 Subject: [PATCH 52/54] fix some tests. --- src/maxdiffusion/tests/attention_test.py | 19 ++++++++++---- .../tests/wan_transformer_test.py | 26 +++++++++++-------- src/maxdiffusion/trainers/wan_trainer.py | 3 ++- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index be4d44f2c..769d04b97 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -37,7 +37,10 @@ def setUp(self): def test_splash_attention(self): """Test numerics of splash attention are equivalent to dot_product""" - pyconfig.initialize([None, os.path.join(THIS_DIR, "..", "configs", "base21.yml")], unittest=True) + pyconfig.initialize([None, os.path.join(THIS_DIR, "..", "configs", "base21.yml"), + 'flash_block_sizes={"block_q" : 512, "block_kv_compute": 512, "block_kv": 512,' + '"block_q_dkv": 512, "block_kv_dkv": 512, "block_kv_dkv_compute": 512,' + '"block_q_dq": 512, "block_kv_dq": 512}',], unittest=True) config = pyconfig.config batch = 8 @@ -47,7 +50,6 @@ def test_splash_attention(self): key1, key2 = jax.random.split(jax.random.PRNGKey(0)) x = jax.random.normal(key1, (batch, length, heads * head_depth)) - dot_product_attention = FlaxAttention( heads * head_depth, heads, @@ -55,7 +57,7 @@ def test_splash_attention(self): split_head_dim=True, attention_kernel="dot_product", mesh=None, - dtype=jnp.bfloat16, + dtype=jnp.bfloat16 ) params = dot_product_attention.init(key2, x)["params"] @@ -64,9 +66,16 @@ def test_splash_attention(self): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - + flash_block_sizes = max_utils.get_flash_block_sizes(config) splash_attention = FlaxAttention( - heads * head_depth, heads, head_depth, split_head_dim=True, attention_kernel="flash", mesh=mesh, dtype=jnp.bfloat16 + heads * head_depth, + heads, + head_depth, + split_head_dim=True, + attention_kernel="flash", + mesh=mesh, + dtype=jnp.bfloat16, + flash_block_sizes=flash_block_sizes ) params = splash_attention.init(key2, x)["params"] diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 5eba28a29..c25b111c2 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -163,8 +163,8 @@ def test_wan_block(self): mesh=mesh, flash_block_sizes=flash_block_sizes, ) - - dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb) + with mesh: + dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb) assert dummy_output.shape == dummy_hidden_states.shape def test_wan_attention(self): @@ -210,10 +210,10 @@ def test_wan_attention(self): dummy_hidden_states = jnp.ones(dummy_hidden_states_shape) dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape) - - dummy_output = attention( - hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb - ) + with mesh: + dummy_output = attention( + hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb + ) assert dummy_output.shape == dummy_hidden_states_shape # dot product @@ -246,7 +246,7 @@ def test_wan_model(self): frames = 21 height = 90 width = 160 - hidden_states_shape = (batch_size, frames, height, width, channels) + hidden_states_shape = (batch_size, channels, frames, height, width) dummy_hidden_states = jnp.ones(hidden_states_shape) key = jax.random.key(0) @@ -266,10 +266,14 @@ def test_wan_model(self): dummy_timestep = jnp.ones((batch_size)) dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) - - dummy_output = wan_model( - hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states - ) + with mesh: + dummy_output = wan_model( + hidden_states=dummy_hidden_states, + timestep=dummy_timestep, + encoder_hidden_states=dummy_encoder_hidden_states, + is_uncond=jnp.array(True, dtype=jnp.bool_), + slg_mask=jnp.zeros(40, dtype=jnp.bool_) + ) assert dummy_output.shape == hidden_states_shape diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index d9e1d5d4c..e107ac76e 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -22,6 +22,7 @@ import jax import jax.tree_util as jtu from flax import nnx +from flax.linen import partitioning as nn_partitioning from ..schedulers import FlaxEulerDiscreteScheduler from .. import max_utils, max_logging, train_utils, maxdiffusion_utils from ..checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT) @@ -115,7 +116,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): for step in np.arange(start_step, self.config.max_train_steps): if self.config.enable_profiler and step == first_profiling_step: max_utils.activate_profiler(self.config) - with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh: + with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): state, train_metric, rng = p_train_step(state, graphdef, data, rng) new_time = datetime.datetime.now() From 50a029d1bc64b4af18aa33eba0687a1e1ab2fcc2 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 16 Jun 2025 19:46:12 +0000 Subject: [PATCH 53/54] lint --- src/maxdiffusion/generate_wan.py | 5 ++++- .../models/wan/transformers/transformer_wan.py | 12 +++--------- src/maxdiffusion/tests/attention_test.py | 18 ++++++++++++------ src/maxdiffusion/tests/wan_transformer_test.py | 2 +- src/maxdiffusion/trainers/wan_trainer.py | 4 +++- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 057a6b25a..760d655cc 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -40,7 +40,9 @@ def run(config): prompt = [config.prompt] * batch_multiplier negative_prompt = [config.negative_prompt] * batch_multiplier - max_logging.log(f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}") + max_logging.log( + f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" + ) videos = pipeline( prompt=prompt, @@ -91,6 +93,7 @@ def run(config): ) print("generation time: ", (time.perf_counter() - s0)) + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) run(pyconfig.config) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 36da470bf..a084447b6 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -315,9 +315,7 @@ def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, t ) # 1. Self-attention - norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype( - hidden_states.dtype - ) + norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb ) @@ -329,13 +327,9 @@ def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, t hidden_states = hidden_states + attn_output # 3. Feed-forward - norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype( - hidden_states.dtype - ) + norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype) ff_output = self.ffn(norm_hidden_states) - hidden_states = (hidden_states + ff_output * c_gate_msa).astype( - hidden_states.dtype - ) + hidden_states = (hidden_states + ff_output * c_gate_msa).astype(hidden_states.dtype) return hidden_states diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index 769d04b97..3b013b791 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -37,10 +37,16 @@ def setUp(self): def test_splash_attention(self): """Test numerics of splash attention are equivalent to dot_product""" - pyconfig.initialize([None, os.path.join(THIS_DIR, "..", "configs", "base21.yml"), - 'flash_block_sizes={"block_q" : 512, "block_kv_compute": 512, "block_kv": 512,' - '"block_q_dkv": 512, "block_kv_dkv": 512, "block_kv_dkv_compute": 512,' - '"block_q_dq": 512, "block_kv_dq": 512}',], unittest=True) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base21.yml"), + 'flash_block_sizes={"block_q" : 512, "block_kv_compute": 512, "block_kv": 512,' + '"block_q_dkv": 512, "block_kv_dkv": 512, "block_kv_dkv_compute": 512,' + '"block_q_dq": 512, "block_kv_dq": 512}', + ], + unittest=True, + ) config = pyconfig.config batch = 8 @@ -57,7 +63,7 @@ def test_splash_attention(self): split_head_dim=True, attention_kernel="dot_product", mesh=None, - dtype=jnp.bfloat16 + dtype=jnp.bfloat16, ) params = dot_product_attention.init(key2, x)["params"] @@ -75,7 +81,7 @@ def test_splash_attention(self): attention_kernel="flash", mesh=mesh, dtype=jnp.bfloat16, - flash_block_sizes=flash_block_sizes + flash_block_sizes=flash_block_sizes, ) params = splash_attention.init(key2, x)["params"] diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index c25b111c2..e4684f043 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -272,7 +272,7 @@ def test_wan_model(self): timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states, is_uncond=jnp.array(True, dtype=jnp.bool_), - slg_mask=jnp.zeros(40, dtype=jnp.bool_) + slg_mask=jnp.zeros(40, dtype=jnp.bool_), ) assert dummy_output.shape == hidden_states_shape diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index e107ac76e..3740e2cf1 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -116,7 +116,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): for step in np.arange(start_step, self.config.max_train_steps): if self.config.enable_profiler and step == first_profiling_step: max_utils.activate_profiler(self.config) - with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules( + self.config.logical_axis_rules + ): state, train_metric, rng = p_train_step(state, graphdef, data, rng) new_time = datetime.datetime.now() From cc2c288a464b50571d44536fdd48c8545067a94c Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 16 Jun 2025 20:32:04 +0000 Subject: [PATCH 54/54] update tests. --- src/maxdiffusion/tests/wan_transformer_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index e4684f043..17741191a 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -17,6 +17,7 @@ import os import jax import jax.numpy as jnp +import pytest import unittest from absl.testing import absltest from flax import nnx @@ -34,6 +35,8 @@ from ..models.normalization_flax import FP32LayerNorm from ..models.attention_flax import FlaxWanAttention +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" + THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -81,6 +84,7 @@ def test_fp32_layer_norm(self): dummy_output = layer(dummy_hidden_states) assert dummy_output.shape == dummy_hidden_states.shape + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_wan_time_text_embedding(self): key = jax.random.key(0) rngs = nnx.Rngs(key) @@ -231,6 +235,7 @@ def test_wan_attention(self): except NotImplementedError: pass + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_wan_model(self): pyconfig.initialize( [