diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py new file mode 100644 index 000000000..5f64d4880 --- /dev/null +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -0,0 +1,66 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from abc import ABC +from flax import nnx +from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) +from ..pipelines.wan.wan_pipeline import WanPipeline +from .. import max_logging, max_utils + +WAN_CHECKPOINT = "WAN_CHECKPOINT" + + +class WanCheckpointer(ABC): + + def __init__(self, config, checkpoint_type): + self.config = config + self.checkpoint_type = checkpoint_type + + self.checkpoint_manager = create_orbax_checkpoint_manager( + self.config.checkpoint_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=checkpoint_type, + dataset_type=config.dataset_type, + ) + + def _create_optimizer(self, model, config, learning_rate): + learning_rate_scheduler = max_utils.create_learning_rate_schedule( + learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps + ) + tx = max_utils.create_optimizer(config, learning_rate_scheduler) + return nnx.Optimizer(model, tx), learning_rate_scheduler + + def load_wan_configs_from_orbax(self, step): + max_logging.log("Restoring stable diffusion configs") + if step is None: + step = self.checkpoint_manager.latest_step() + if step is None: + return None + + def load_diffusers_checkpoint(self): + pipeline = WanPipeline.from_pretrained(self.config) + return pipeline + + def load_checkpoint(self, step=None): + model_configs = self.load_wan_configs_from_orbax(step) + + if model_configs: + raise NotImplementedError("model configs should not exist in orbax") + else: + pipeline = self.load_diffusers_checkpoint() + + return pipeline diff --git a/src/maxdiffusion/configs/base_wan_t2v.yml b/src/maxdiffusion/configs/base_wan_14b.yml similarity index 83% rename from src/maxdiffusion/configs/base_wan_t2v.yml rename to src/maxdiffusion/configs/base_wan_14b.yml index 002b9e73b..1dd81b075 100644 --- a/src/maxdiffusion/configs/base_wan_t2v.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -18,6 +18,10 @@ run_name: '' metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ write_metrics: True + +timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written. +write_timing_metrics: True + gcs_metrics: False # If true save config to GCS in {base_output_directory}/{run_name}/ save_config_to_gcs: False @@ -25,18 +29,8 @@ log_period: 100 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' -# Flux params -flux_name: "flux-dev" -max_sequence_length: 512 -time_shift: True -base_shift: 0.5 -max_shift: 1.15 -# offloads t5 encoder after text encoding to save memory. -offload_encoders: True - - unet_checkpoint: '' -revision: 'refs/pr/95' +revision: '' # This will convert the weights to this dtype. # When running inference on TPUv5e, use weights_dtype: 'bfloat16' weights_dtype: 'bfloat16' @@ -59,24 +53,9 @@ split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te flash_block_sizes: {} -# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. -# flash_block_sizes: { -# "block_q" : 1536, -# "block_kv_compute" : 1536, -# "block_kv" : 1536, -# "block_q_dkv" : 1536, -# "block_kv_dkv" : 1536, -# "block_kv_dkv_compute" : 1536, -# "block_q_dq" : 1536, -# "block_kv_dq" : 1536 -# } # GroupNorm groups norm_num_groups: 32 -# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch -# else they will be loaded from pretrained_model_name_or_path -train_new_unet: False - # train text_encoder - Currently not supported for SDXL train_text_encoder: False text_encoder_learning_rate: 4.25e-6 @@ -133,15 +112,17 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], + ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], - ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], + ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], + ['conv_in', 'fsdp'] ] data_sharding: [['data', 'fsdp', 'tensor']] @@ -152,8 +133,8 @@ data_sharding: [['data', 'fsdp', 'tensor']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 -ici_data_parallelism: -1 -ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 # Dataset @@ -192,17 +173,19 @@ checkpoint_every: -1 enable_single_replica_ckpt_restoring: False # Training loop -learning_rate: 4.e-7 +learning_rate: 1.e-5 scale_lr: False max_train_samples: -1 # max_train_steps takes priority over num_train_epochs. -max_train_steps: 200 +max_train_steps: 1500 num_train_epochs: 1 seed: 0 output_dir: 'sdxl-model-finetuned' per_device_batch_size: 1 +# If global_batch_size % jax.device_count is not 0, use FSDP sharding. +global_batch_size: 0 -warmup_steps_fraction: 0.0 +warmup_steps_fraction: 0.1 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before @@ -212,7 +195,7 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. -adam_weight_decay: 1.e-2 # AdamW Weight decay +adam_weight_decay: 0 # AdamW Weight decay max_grad_norm: 1.0 enable_profiler: False @@ -222,14 +205,25 @@ skip_first_n_steps_for_profiler: 5 profiler_steps: 10 # Generation parameters -prompt: "A magical castle in the middle of a forest, artistic drawing" -prompt_2: "A magical castle in the middle of a forest, artistic drawing" -negative_prompt: "purple, red" +prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True -guidance_scale: 3.5 +height: 480 +width: 832 +num_frames: 81 +guidance_scale: 5.0 +flow_shift: 3.0 + +# skip layer guidance +slg_layers: [9] +slg_start: 0.2 +slg_end: 1.0 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 -num_inference_steps: 50 +num_inference_steps: 30 +fps: 24 +save_final_checkpoint: False # SDXL Lightning parameters lightning_from_pt: True diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py new file mode 100644 index 000000000..760d655cc --- /dev/null +++ b/src/maxdiffusion/generate_wan.py @@ -0,0 +1,103 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence +import jax +import time +from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline +from maxdiffusion import pyconfig, max_logging +from absl import app +from maxdiffusion.utils import export_to_video + + +def run(config): + print("seed: ", config.seed) + pipeline = WanPipeline.from_pretrained(config) + s0 = time.perf_counter() + + # Skip layer guidance + slg_layers = config.slg_layers + slg_start = config.slg_start + slg_end = config.slg_end + # If global_batch_size % jax.device_count is not 0, use FSDP sharding. + global_batch_size = config.global_batch_size + if global_batch_size != 0: + batch_multiplier = global_batch_size + else: + batch_multiplier = jax.device_count() * config.per_device_batch_size + + prompt = [config.prompt] * batch_multiplier + negative_prompt = [config.negative_prompt] * batch_multiplier + + max_logging.log( + f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" + ) + + videos = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + slg_layers=slg_layers, + slg_start=slg_start, + slg_end=slg_end, + ) + + print("compile time: ", (time.perf_counter() - s0)) + for i in range(len(videos)): + export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps) + s0 = time.perf_counter() + videos = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + slg_layers=slg_layers, + slg_start=slg_start, + slg_end=slg_end, + ) + print("generation time: ", (time.perf_counter() - s0)) + for i in range(len(videos)): + export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps) + + s0 = time.perf_counter() + with jax.profiler.trace("/tmp/trace/"): + videos = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + slg_layers=slg_layers, + slg_start=slg_start, + slg_end=slg_end, + ) + print("generation time: ", (time.perf_counter() - s0)) + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index 43400a62e..b9b1abdcb 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -287,6 +287,34 @@ def get_dummy_flux_inputs(config, pipeline, batch_size): return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) +def get_dummy_wan_inputs(config, pipeline, batch_size): + latents = pipeline.prepare_latents( + batch_size, + vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, + vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_channels_latents=pipeline.transformer.config.in_channels, + ) + bsz = latents.shape[0] + prompt_embeds = jax.random.normal(jax.random.key(config.seed), (batch_size, 512, 4096)) + timesteps = jnp.array([0] * bsz, dtype=jnp.int32) + return (latents, prompt_embeds, timesteps) + + +def calculate_wan_tflops(config, pipeline, batch_size, rngs, train): + """ + Calculates jflux tflops. + batch_size should be per_device_batch_size * jax.local_device_count() or attention's shard_map won't + cache the compilation when flash is enabled. + """ + (latents, prompt_embeds, timesteps) = get_dummy_wan_inputs(config, pipeline, batch_size) + return max_utils.calculate_model_tflops( + pipeline.transformer, + ) + + def calculate_flux_tflops(config, pipeline, batch_size, rngs, train): """ Calculates jflux tflops. diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 2f8946056..006614f87 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -14,9 +14,11 @@ import functools import math -from typing import Optional +from typing import Optional, Callable, Tuple import flax.linen as nn +from flax import nnx import jax +from jax.sharding import PartitionSpec import jax.numpy as jnp from jax.experimental import shard_map from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask @@ -42,274 +44,317 @@ Quant = quantizations.AqtQuantization -Quant = quantizations.AqtQuantization - - def _maybe_aqt_einsum(quant: Quant): return jnp.einsum if quant is None else quant.einsum() -class AttentionOp(nn.Module): - mesh: Mesh - attention_kernel: str - scale: int - heads: int - dim_head: int - use_memory_efficient_attention: bool = False - split_head_dim: bool = False - float32_qk_product: bool = True - flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) - flash_min_seq_length: int = 4096 - flash_block_sizes: BlockSizes = None - dtype: DType = jnp.float32 - quant: Quant = None +def _check_attention_inputs(query: Array, key: Array, value: Array) -> None: + """Check attention inputs.""" - def setup(self): - if self.attention_kernel == "cudnn_flash_te": - from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error + assert key.ndim == value.ndim, "k, v must have same rank." + assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match." + assert key.shape[-2] == value.shape[-2], "k, v num_kv_heads must match." + assert key.shape[-3] == value.shape[-3], "k, v lengths must match." + assert query.shape[-1] == key.shape[-1], "q, k depths must match." - self.dpa_layer = DotProductAttention( - head_dim=self.dim_head, - num_attention_heads=self.heads, - num_gqa_groups=self.heads, - attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal' - attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' - # attention_dropout=self.dropout_rate, - dropout_rng_name="aqt", - dtype=self.dtype, - # float32_logits=self.float32_logits, - qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' - scale_factor=self.scale, - transpose_batch_sequence=False, - ) - def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None: - """Check attention inputs.""" +def _reshape_data_from_cudnn_flash(tensor): + # reshapes from [b, s, h, d] back to [b, s, h * d] + return tensor.reshape(tensor.shape[0], tensor.shape[1], -1) - assert key.ndim == value.ndim, "k, v must have same rank." - assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match." - assert key.shape[-2] == value.shape[-2], "k, v num_kv_heads must match." - assert key.shape[-3] == value.shape[-3], "k, v lengths must match." - assert query.shape[-1] == key.shape[-1], "q, k depths must match." - def apply_attention(self, query: Array, key: Array, value: Array): - """Routes to different attention kernels.""" - self.check_attention_inputs(query, key, value) - - if self.attention_kernel == "flash": - can_use_flash_attention = ( - query.shape[1] >= self.flash_min_seq_length - and key.shape[1] >= self.flash_min_seq_length - and value.shape[1] >= self.flash_min_seq_length - ) - else: - can_use_flash_attention = True - - if self.attention_kernel == "dot_product" or self.use_memory_efficient_attention or not can_use_flash_attention: - return self.apply_attention_dot(query, key, value) - elif self.attention_kernel == "flash": - return self.tpu_flash_attention(query, key * self.scale, value) - elif self.attention_kernel == "cudnn_flash_te": - return self.cudnn_flash_attention(query, key, value) - else: - raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.") +def _reshape_data_for_cudnn_flash(tensor, heads): + # reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format) + batch, seq, heads_and_dim_head = tensor.shape + tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads) + return tensor - def tpu_flash_attention(self, query: jax.Array, key: jax.Array, value: jax.Array) -> jax.Array: - """TPU Flash Attention""" - query, kv_size = self.reshape_data_for_flash(query) - key, _ = self.reshape_data_for_flash(key) - value, _ = self.reshape_data_for_flash(value) +def _reshape_batch_dim_to_heads(tensor, heads): + batch_size, seq_len, dim = tensor.shape + head_size = heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor - axis_names = nn.logical_to_mesh_axes(self.flash_axis_names) - @functools.partial( - shard_map.shard_map, - mesh=self.mesh, - in_specs=( - axis_names, - axis_names, - axis_names, - ), - out_specs=axis_names, - check_rep=False, - ) - def wrap_flash_attention(query, key, value): - if self.flash_block_sizes: - block_sizes = self.flash_block_sizes - else: - block_sizes = splash_attention_kernel.BlockSizes( - block_q=min(512, query.shape[2]), - block_kv_compute=min(512, key.shape[2]), - block_kv=min(512, key.shape[2]), - block_q_dkv=min(512, query.shape[2]), - block_kv_dkv=min(512, key.shape[2]), - block_kv_dkv_compute=min(512, query.shape[2]), - block_q_dq=min(512, query.shape[2]), - block_kv_dq=min(512, query.shape[2]), - ) - masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])] - multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks) - splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes - ) - return jax.vmap(splash_kernel)(query, key, value) - - devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"] - # This warning might show up when doing model eval for example, when calculating model flops - # and that is expected. - if not (query.shape[0] / devices_in_data_fsdp).is_integer(): - max_logging.log( - "Warning, batch dimension should be shardable among the devices in data and fsdp" - f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}" - ) - x = wrap_flash_attention(query, key, value) - x = x[:, :, :, :kv_size] - x = self.reshape_heads_to_head_dim(x) +def _reshape_heads_to_batch_dim(tensor, heads): + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + head_size = heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + else: + batch_size, head_size, seq_len, head_dim = tensor.shape + tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim) - return x + return tensor - def cudnn_flash_attention( - self, - query: Array, - key: Array, - value: Array, - ) -> Array: - """CUDNN Flash Attention with Transformer Engine. - 1. Stable API, supports GQA - 2. Supports head_dim till 128; head_dim=256 support will be added soon - """ - # These imports are only meant to work in a GPU build. - # copied from tpu_flash_attention - query = self.reshape_data_for_cudnn_flash(query) - key = self.reshape_data_for_cudnn_flash(key) - value = self.reshape_data_for_cudnn_flash(value) - - cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV) - axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names) - - query = nn.with_logical_constraint(query, axis_names) - key = nn.with_logical_constraint(key, axis_names) - value = nn.with_logical_constraint(value, axis_names) - - @functools.partial( - shard_map.shard_map, - mesh=self.mesh, - in_specs=(axis_names, axis_names, axis_names), - out_specs=axis_names, - check_rep=False, + +def _reshape_heads_to_head_dim(tensor): + # takes a tensor of shape [b, h, s, d] and reshapes to [b, s, h * d] + # This is used to transform the output of flash attention back into the format of other attention outputs + b, h, s, d = tensor.shape + tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3]) + return jnp.reshape(tensor, (b, -1, h * d)) + + +def _unflatten_heads(tensor, heads): + # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format) + batch, seq, heads_and_dim_head = tensor.shape + tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads) + # Transpose to ('batch', 'heads', 'length', 'kv') + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + return tensor + + +def _reshape_data_for_flash(tensor, heads, flash_block_size): + """ + Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. + """ + if tensor.ndim != 4: + tensor = _unflatten_heads(tensor, heads) + + # pad head_dim to 128 if less than that. + kv_size = tensor.shape[-1] + head_dim_pad = 0 + if kv_size < 128: + head_dim_pad = 128 - kv_size + + # pad seq_len to a multiple of flash_block_size if needed. + seq_len = tensor.shape[2] + # remainder + rem = seq_len % flash_block_size + seq_len_pad = 0 + if rem != 0: + # multiplier + mul = seq_len // flash_block_size + # pad to the closest multiplier of flash_block_size + seq_len_pad = (mul + 1) * flash_block_size - seq_len + + if kv_size < 128 or rem != 0: + npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad)) + tensor = jnp.pad(tensor, npad) + + return tensor, kv_size, seq_len + + +def _tpu_flash_attention( + query: jax.Array, + key: jax.Array, + value: jax.Array, + heads: int, + mesh: Mesh, + flash_axis_names: AxisNames, + flash_block_sizes: BlockSizes, + dtype: jnp.dtype = jnp.float32, +) -> jax.Array: + """TPU Flash Attention""" + + max_block_size = 1024 if dtype == jnp.bfloat16 else 512 + if flash_block_sizes: + block_sizes = flash_block_sizes + else: + block_sizes = splash_attention_kernel.BlockSizes( + block_q=min(max_block_size, query.shape[2]), + block_kv_compute=min(max_block_size, key.shape[2]), + block_kv=min(max_block_size, key.shape[2]), + block_q_dkv=min(max_block_size, query.shape[2]), + block_kv_dkv=min(max_block_size, key.shape[2]), + block_kv_dkv_compute=min(max_block_size, query.shape[2]), + block_q_dq=min(max_block_size, query.shape[2]), + block_kv_dq=min(max_block_size, query.shape[2]), ) - def wrap_flash_attention(query, key, value): - return jax.vmap(self.dpa_layer)(query, key, value, mask=None) - - out = wrap_flash_attention(query, key, value) - return self.reshape_data_from_cudnn_flash(out) - - def apply_attention_dot(self, query: Array, key: Array, value: Array): - """Apply Attention.""" - if self.split_head_dim: - b = key.shape[0] - query_states = jnp.reshape(query, (b, -1, self.heads, self.dim_head)) - key_states = jnp.reshape(key, (b, -1, self.heads, self.dim_head)) - value_states = jnp.reshape(value, (b, -1, self.heads, self.dim_head)) + + query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q) + key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute) + value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute) + + axis_names = nn.logical_to_mesh_axes(flash_axis_names) + + @functools.partial( + shard_map.shard_map, + mesh=mesh, + in_specs=( + axis_names, + axis_names, + axis_names, + ), + out_specs=axis_names, + check_rep=False, + ) + def wrap_flash_attention(query, key, value): + masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])] + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks) + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes + ) + return jax.vmap(splash_kernel)(query, key, value) + + devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"] + # This warning might show up when doing model eval for example, when calculating model flops + # and that is expected. + if not (query.shape[0] / devices_in_data_fsdp).is_integer(): + max_logging.log( + "Warning, batch dimension should be shardable among the devices in data and fsdp" + f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}" + ) + x = wrap_flash_attention(query, key, value) + x = x[:, :, :query_seq_len, :kv_size] + x = _reshape_heads_to_head_dim(x) + + return x + + +def _apply_attention_dot( + query: Array, + key: Array, + value: Array, + dtype: jnp.dtype, + heads: int, + dim_head: int, + scale: float, + split_head_dim: bool, + float32_qk_product: bool, + use_memory_efficient_attention: bool, +): + """Apply Attention.""" + if split_head_dim: + b = key.shape[0] + query_states = jnp.reshape(query, (b, -1, heads, dim_head)) + key_states = jnp.reshape(key, (b, -1, heads, dim_head)) + value_states = jnp.reshape(value, (b, -1, heads, dim_head)) + else: + query_states = _reshape_heads_to_batch_dim(query, heads) + key_states = _reshape_heads_to_batch_dim(key, heads) + value_states = _reshape_heads_to_batch_dim(value, heads) + + if float32_qk_product: + query_states = query_states.astype(jnp.float32) + key_states = key_states.astype(jnp.float32) + + if use_memory_efficient_attention: + query_states = query_states.transpose(1, 0, 2) + key_states = key_states.transpose(1, 0, 2) + value_states = value_states.transpose(1, 0, 2) + + # this if statement create a chunk size for each layer of the unet + # the chunk size is equal to the query_length dimension of the deepest layer of the unet + + flatten_latent_dim = query_states.shape[-3] + if flatten_latent_dim % 64 == 0: + query_chunk_size = int(flatten_latent_dim / 64) + elif flatten_latent_dim % 16 == 0: + query_chunk_size = int(flatten_latent_dim / 16) + elif flatten_latent_dim % 4 == 0: + query_chunk_size = int(flatten_latent_dim / 4) else: - query_states = self.reshape_heads_to_batch_dim(query) - key_states = self.reshape_heads_to_batch_dim(key) - value_states = self.reshape_heads_to_batch_dim(value) - - if self.float32_qk_product: - query_states = query_states.astype(jnp.float32) - key_states = key_states.astype(jnp.float32) - - if self.use_memory_efficient_attention: - query_states = query_states.transpose(1, 0, 2) - key_states = key_states.transpose(1, 0, 2) - value_states = value_states.transpose(1, 0, 2) - - # this if statement create a chunk size for each layer of the unet - # the chunk size is equal to the query_length dimension of the deepest layer of the unet - - flatten_latent_dim = query_states.shape[-3] - if flatten_latent_dim % 64 == 0: - query_chunk_size = int(flatten_latent_dim / 64) - elif flatten_latent_dim % 16 == 0: - query_chunk_size = int(flatten_latent_dim / 16) - elif flatten_latent_dim % 4 == 0: - query_chunk_size = int(flatten_latent_dim / 4) - else: - query_chunk_size = int(flatten_latent_dim) - - hidden_states = jax_memory_efficient_attention( - query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 - ) + query_chunk_size = int(flatten_latent_dim) + + hidden_states = jax_memory_efficient_attention( + query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 + ) - hidden_states = hidden_states.transpose(1, 0, 2) + hidden_states = hidden_states.transpose(1, 0, 2) + else: + if split_head_dim: + attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states) else: - if self.split_head_dim: - attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states) - else: - attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) + attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) - attention_scores = attention_scores * self.scale - attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2) + attention_scores = attention_scores * scale + attention_probs = nn.softmax(attention_scores, axis=-1 if split_head_dim else 2) - attention_probs = attention_probs.astype(self.dtype) + attention_probs = attention_probs.astype(dtype) - # attend to values - if self.split_head_dim: - hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states) - b = hidden_states.shape[0] - hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head)) - else: - hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + # attend to values + if split_head_dim: + hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states) + b = hidden_states.shape[0] + hidden_states = jnp.reshape(hidden_states, (b, -1, heads * dim_head)) + else: + hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) + hidden_states = _reshape_batch_dim_to_heads(hidden_states, heads) - return hidden_states + return hidden_states - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def reshape_data_for_cudnn_flash(self, tensor): - # reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format) - batch, seq, heads_and_dim_head = tensor.shape - tensor = tensor.reshape(batch, seq, self.heads, heads_and_dim_head // self.heads) - return tensor - - def reshape_data_from_cudnn_flash(self, tensor): - # reshapes from [b, s, h, d] back to [b, s, h * d] - return tensor.reshape(tensor.shape[0], tensor.shape[1], -1) - - def reshape_data_for_flash(self, tensor): - # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format) - batch, seq, heads_and_dim_head = tensor.shape - tensor = tensor.reshape(batch, seq, self.heads, heads_and_dim_head // self.heads) - # Transpose to ('batch', 'heads', 'length', 'kv') - tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - kv_size = tensor.shape[-1] - if kv_size < 128: - npad = ((0, 0), (0, 0), (0, 0), (0, 128 - kv_size)) - tensor = jnp.pad(tensor, npad) - return tensor, kv_size +def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, mesh: Mesh, dpa_layer: Callable) -> Array: + """CUDNN Flash Attention with Transformer Engine. + 1. Stable API, supports GQA + 2. Supports head_dim till 128; head_dim=256 support will be added soon + """ + # These imports are only meant to work in a GPU build. + # copied from tpu_flash_attention + query = _reshape_data_for_cudnn_flash(query, heads) + key = _reshape_data_for_cudnn_flash(key, heads) + value = _reshape_data_for_cudnn_flash(value, heads) + + cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV) + axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names) + + query = nn.with_logical_constraint(query, axis_names) + key = nn.with_logical_constraint(key, axis_names) + value = nn.with_logical_constraint(value, axis_names) + + @functools.partial( + shard_map.shard_map, + mesh=mesh, + in_specs=(axis_names, axis_names, axis_names), + out_specs=axis_names, + check_rep=False, + ) + def wrap_flash_attention(query, key, value): + return jax.vmap(dpa_layer)(query, key, value, mask=None) + + out = wrap_flash_attention(query, key, value) + return _reshape_data_from_cudnn_flash(out) + + +def _apply_attention( + query: Array, + key: Array, + value: Array, + heads: int, + dim_head: int, + split_head_dim: bool, + float32_qk_product: bool, + attention_kernel: str, + flash_min_seq_length: int, + use_memory_efficient_attention: bool, + scale: float, + dtype: jnp.dtype, + mesh: Mesh, + flash_axis_names: AxisNames, + flash_block_sizes: BlockSizes, + dpa_layer: Callable, +): + """Routes to different attention kernels.""" + _check_attention_inputs(query, key, value) + seq_len_idx = 1 + if query.ndim == 4: + seq_len_idx = 2 + if attention_kernel == "flash": + can_use_flash_attention = ( + query.shape[seq_len_idx] >= flash_min_seq_length + and key.shape[seq_len_idx] >= flash_min_seq_length + and value.shape[seq_len_idx] >= flash_min_seq_length + ) + else: + can_use_flash_attention = True - def reshape_heads_to_head_dim(self, tensor): - # takes a tensor of shape [b, h, s, d] and reshapes to [b, s, h * d] - # This is used to transform the output of flash attention back into the format of other attention outputs - b, h, s, d = tensor.shape - tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3]) - return jnp.reshape(tensor, (b, -1, h * d)) + if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention: + return _apply_attention_dot( + query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention + ) + elif attention_kernel == "flash": + return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype) + elif attention_kernel == "cudnn_flash_te": + return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) + else: + raise ValueError(f"Unexpected attention kernel {attention_kernel=}.") def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): @@ -416,130 +461,288 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) -class FlaxWanAttention(nn.Module): - query_dim: int - heads: int = 8 - dim_head: int = 64 - dropout: float = 0.0 +class NNXAttentionOp(nnx.Module): + + def __init__( + self, + mesh: Mesh, + attention_kernel: str, + scale: int, + heads: int, + dim_head: int, + use_memory_efficient_attention: bool = False, + split_head_dim: bool = False, + float32_qk_product: bool = True, + flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV), + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + dtype: DType = jnp.float32, + quant: Quant = None, + ): + self.dpa_layer = None + if attention_kernel == "cudnn_flash_te": + raise NotImplementedError(f"{self} has not been tested with {attention_kernel}") + + self.mesh = mesh + self.scale = scale + self.heads = heads + self.dim_head = dim_head + self.attention_kernel = attention_kernel + self.use_memory_efficient_attention = use_memory_efficient_attention + self.split_head_dim = split_head_dim + self.float32_qk_product = float32_qk_product + self.flash_axis_names = flash_axis_names + self.flash_min_seq_length = flash_min_seq_length + self.flash_block_sizes = flash_block_sizes + self.dtype = dtype + self.quant = quant + + def apply_attention(self, query: Array, key: Array, value: Array): + return _apply_attention( + query=query, + key=key, + value=value, + heads=self.heads, + dim_head=self.dim_head, + split_head_dim=self.split_head_dim, + float32_qk_product=self.float32_qk_product, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + use_memory_efficient_attention=self.use_memory_efficient_attention, + scale=self.scale, + dtype=self.dtype, + mesh=self.mesh, + flash_axis_names=self.flash_axis_names, + flash_block_sizes=self.flash_block_sizes, + dpa_layer=self.dpa_layer, + ) + + +class AttentionOp(nn.Module): + mesh: Mesh + attention_kernel: str + scale: int + heads: int + dim_head: int use_memory_efficient_attention: bool = False split_head_dim: bool = False - attention_kernel: str = "dot_product" + float32_qk_product: bool = True + flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None - mesh: jax.sharding.Mesh = None - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - query_axis_names: AxisNames = (BATCH, LENGTH, HEAD) - key_axis_names: AxisNames = (BATCH, LENGTH, HEAD) - value_axis_names: AxisNames = (BATCH, LENGTH, HEAD) - out_axis_names: AxisNames = (BATCH, LENGTH, EMBED) - precision: jax.lax.Precision = None - qkv_bias: bool = False + dtype: DType = jnp.float32 + quant: Quant = None def setup(self): - if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: - raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") - inner_dim = self.dim_head * self.heads - scale = self.dim_head**-0.5 + self.dpa_layer = None + if self.attention_kernel == "cudnn_flash_te": + from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error - self.attention_op = AttentionOp( - mesh=self.mesh, - attention_kernel=self.attention_kernel, - scale=scale, + self.dpa_layer = DotProductAttention( + head_dim=self.dim_head, + num_attention_heads=self.heads, + num_gqa_groups=self.heads, + attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal' + attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + # attention_dropout=self.dropout_rate, + dropout_rng_name="aqt", + dtype=self.dtype, + # float32_logits=self.float32_logits, + qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + scale_factor=self.scale, + transpose_batch_sequence=False, + ) + + def apply_attention(self, query: Array, key: Array, value: Array): + return _apply_attention( + query=query, + key=key, + value=value, heads=self.heads, dim_head=self.dim_head, + split_head_dim=self.split_head_dim, + float32_qk_product=self.float32_qk_product, + attention_kernel=self.attention_kernel, flash_min_seq_length=self.flash_min_seq_length, use_memory_efficient_attention=self.use_memory_efficient_attention, - split_head_dim=self.split_head_dim, - flash_block_sizes=self.flash_block_sizes, + scale=self.scale, dtype=self.dtype, + mesh=self.mesh, + flash_axis_names=self.flash_axis_names, + flash_block_sizes=self.flash_block_sizes, + dpa_layer=self.dpa_layer, + ) + + +class FlaxWanAttention(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + eps: float = 1e-6, + qk_norm: str = "rms_norm_across_heads", + use_memory_efficient_attention: bool = False, + split_head_dim: bool = False, + attention_kernel: str = "flash", + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + query_axis_names: AxisNames = (BATCH, LENGTH, HEAD), + key_axis_names: AxisNames = (BATCH, LENGTH, HEAD), + value_axis_names: AxisNames = (BATCH, LENGTH, HEAD), + out_axis_names: AxisNames = (BATCH, LENGTH, EMBED), + precision: jax.lax.Precision = None, + qkv_bias: bool = False, + quant: Quant = None, + ): + if attention_kernel == "cudnn_flash_te": + raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}") + + if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None: + raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") + self.dim_head = dim_head + self.heads = heads + self.inner_dim = dim_head * heads + scale = dim_head**-0.5 + self.qk_norm = qk_norm + self.query_axis_names = query_axis_names + self.key_axis_names = key_axis_names + self.value_axis_names = value_axis_names + self.out_axis_names = out_axis_names + + self.attention_op = NNXAttentionOp( + mesh=mesh, + attention_kernel=attention_kernel, + scale=scale, + heads=heads, + dim_head=dim_head, + use_memory_efficient_attention=use_memory_efficient_attention, + split_head_dim=split_head_dim, float32_qk_product=False, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + dtype=dtype, + quant=quant, ) kernel_axes = ("embed", "heads") - qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes) + qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes) - qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "heads")) - - self.query = nn.Dense( - inner_dim, + self.query = nnx.Linear( + rngs=rngs, + in_features=self.inner_dim, + out_features=self.inner_dim, kernel_init=qkv_init_kernel, - use_bias=False, - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="to_q", - precision=self.precision, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) - self.key = nn.Dense( - inner_dim, + self.key = nnx.Linear( + rngs=rngs, + in_features=self.inner_dim, + out_features=self.inner_dim, kernel_init=qkv_init_kernel, - use_bias=False, - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="to_k", - precision=self.precision, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) - self.value = nn.Dense( - inner_dim, + self.value = nnx.Linear( + rngs=rngs, + in_features=self.inner_dim, + out_features=self.inner_dim, kernel_init=qkv_init_kernel, - use_bias=False, - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="to_v", - precision=self.precision, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) - self.query_norm = nn.RMSNorm( - dtype=self.dtype, - scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), - param_dtype=self.weights_dtype, - ) - self.key_norm = nn.RMSNorm( - dtype=self.dtype, - scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), - param_dtype=self.weights_dtype, + self.proj_attn = nnx.Linear( + rngs=rngs, + in_features=self.inner_dim, + out_features=self.inner_dim, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) - self.proj_attn = nn.Dense( - self.query_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("heads", "embed")), - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="to_out_0", - precision=self.precision, - ) - self.dropout_layer = nn.Dropout(rate=self.dropout) + self.norm_q = None + self.norm_k = None + if qk_norm is not None: + self.norm_q = nnx.RMSNorm( + num_features=self.inner_dim, + rngs=rngs, + epsilon=eps, + dtype=dtype, + scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)), + param_dtype=weights_dtype, + ) - def call( - self, - hidden_states: Array, - encoder_hidden_states: Optional[Array], - rotary_emb: Optional[Array], - deterministic: bool = True, - ): - encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + self.norm_k = nnx.RMSNorm( + num_features=self.inner_dim, + rngs=rngs, + dtype=dtype, + scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)), + param_dtype=weights_dtype, + ) + + def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]: + dtype = xq.dtype + reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) + reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) + + xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1]) + xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) + + xq_out_complex = xq_ * freqs_cis + xk_out_complex = xk_ * freqs_cis + + xq_out = jnp.stack([jnp.real(xq_out_complex), jnp.imag(xq_out_complex)], axis=-1).reshape(xq.shape).astype(dtype) + xk_out = jnp.stack([jnp.real(xk_out_complex), jnp.imag(xk_out_complex)], axis=-1).reshape(xk.shape).astype(dtype) + + return xq_out, xk_out + + def __call__( + self, hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None + ) -> jax.Array: + hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) + encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor")) + dtype = hidden_states.dtype + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states query_proj = self.query(hidden_states) key_proj = self.key(encoder_hidden_states) value_proj = self.value(encoder_hidden_states) - query_proj = self.query_norm(query_proj) - key_proj = self.key_norm(key_proj) + if self.qk_norm: + query_proj = self.norm_q(query_proj) + key_proj = self.norm_k(key_proj) + if rotary_emb is not None: + query_proj = _unflatten_heads(query_proj, self.heads) + key_proj = _unflatten_heads(key_proj, self.heads) + value_proj = _unflatten_heads(value_proj, self.heads) + query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) + query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec("data", "tensor", None, None)) + key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec("data", "tensor", None, None)) + value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec("data", "tensor", None, None)) - if rotary_emb: - query_proj, key_proj = self.apply_rope(query_proj, key_proj, rotary_emb) - - query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names) - key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) - value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) + attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) + attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", None, None)) - hidden_states = self.attention_op.apply_attention(query_proj, key_proj, value_proj) + attn_output = attn_output.astype(dtype=dtype) - hidden_states = self.proj_attn(hidden_states) - hidden_states = nn.with_logical_constraint(hidden_states, (BATCH, LENGTH, HEAD)) - return self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.proj_attn(attn_output) + return hidden_states class FlaxFluxAttention(nn.Module): diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index cc961e131..d994b46e7 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import math - +from typing import Optional import flax.linen as nn +from flax import nnx import jax.numpy as jnp from typing import List, Union import jax +from .modeling_flax_utils import get_activation def get_sinusoidal_embeddings( @@ -57,6 +59,100 @@ def get_sinusoidal_embeddings( return signal +class NNXTimestepEmbedding(nnx.Module): + r""" + Time step Embedding Module. Learns embeddings for input time steps. + + Args: + time_embed_dim (`int`, *optional*, defaults to `32`): + Time step embedding dimension + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + def __init__( + self, + rngs: nnx.Rngs, + in_channels: int, + time_embed_dim: int = 32, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim: int = None, + sample_proj_bias=True, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): + self.linear_1 = nnx.Linear( + rngs=rngs, + in_features=in_channels, + out_features=time_embed_dim, + use_bias=sample_proj_bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "embed", + "mlp", + ), + ), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + ) + + if cond_proj_dim is not None: + self.cond_proj = nnx.Linear( + rngs=rngs, + ) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + + self.linear_2 = nnx.Linear( + rngs=rngs, + in_features=time_embed_dim, + out_features=time_embed_dim_out, + use_bias=sample_proj_bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "mlp", + "embed", + ), + ), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + ) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def __call__(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + class FlaxTimestepEmbedding(nn.Module): r""" Time step Embedding Module. Learns embeddings for input time steps. @@ -80,6 +176,26 @@ def __call__(self, temb): return temb +class NNXFlaxTimesteps(nnx.Module): + + def __init__( + self, + dim: int = 32, + flip_sin_to_cos: bool = False, + freq_shift: float = 1.0, + scale: int = 1, + ): + self.dim = dim + self.flip_sin_to_cos = flip_sin_to_cos + self.freq_shift = freq_shift + self.scale = scale + + def __call__(self, timesteps): + return get_sinusoidal_embeddings( + timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift + ) + + class FlaxTimesteps(nn.Module): r""" Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 @@ -102,7 +218,13 @@ def __call__(self, timesteps): def get_1d_rotary_pos_embed( - dim: int, pos: Union[jnp.array, int], theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0, freqs_dtype=jnp.float32 + dim: int, + pos: Union[jnp.array, int], + theta: float = 10000.0, + linear_factor=1.0, + ntk_factor=1.0, + freqs_dtype=jnp.float32, + use_real: bool = True, ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -115,13 +237,77 @@ def get_1d_rotary_pos_embed( theta = theta * ntk_factor freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor freqs = jnp.outer(pos, freqs) - freqs_cos = jnp.cos(freqs) - freqs_sin = jnp.sin(freqs) - out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1) - + if use_real: + # Flux + freqs_cos = jnp.cos(freqs) + freqs_sin = jnp.sin(freqs) + out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1) + else: + # Wan 2.1 + out = jnp.exp(1j * freqs) return out +class NNXPixArtAlphaTextProjection(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + in_features: int, + hidden_size: int, + out_features: int = None, + act_fn: str = "gelu_tanh", + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): + if out_features is None: + out_features = hidden_size + + self.linear_1 = nnx.Linear( + rngs=rngs, + in_features=in_features, + out_features=hidden_size, + use_bias=True, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "embed", + "mlp", + ), + ), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + ) + self.act_1 = get_activation(act_fn) + + self.linear_2 = nnx.Linear( + rngs=rngs, + in_features=hidden_size, + out_features=out_features, + use_bias=True, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "mlp", + "embed", + ), + ), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + ) + + def __call__(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. diff --git a/src/maxdiffusion/models/modeling_flax_utils.py b/src/maxdiffusion/models/modeling_flax_utils.py index b93ba8396..5e08a2eb8 100644 --- a/src/maxdiffusion/models/modeling_flax_utils.py +++ b/src/maxdiffusion/models/modeling_flax_utils.py @@ -42,6 +42,22 @@ logger = logging.get_logger(__name__) +# gelu and gelu_tanh both use approximate=True by default +_ACTIVATIONS = { + "swish": jax.nn.silu, + "silu": jax.nn.silu, + "relu": jax.nn.relu, + "gelu": jax.nn.gelu, + "gelu_tanh": jax.nn.gelu, + "mish": jax.nn.mish, +} + + +def get_activation(name: str): + func = _ACTIVATIONS.get(name) + if func is None: + raise ValueError(f"Unknown activation function: {name}") + return func class FlaxModelMixin(PushToHubMixin): diff --git a/src/maxdiffusion/models/normalization_flax.py b/src/maxdiffusion/models/normalization_flax.py index ea3b970d8..2ba658d4b 100644 --- a/src/maxdiffusion/models/normalization_flax.py +++ b/src/maxdiffusion/models/normalization_flax.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp import flax.linen as nn +from flax import nnx class AdaLayerNormContinuous(nn.Module): @@ -147,3 +148,21 @@ def __call__(self, x, emb): else: raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided. Supported ones are: 'layer_norm'.") return x, gate_msa + + +class FP32LayerNorm(nnx.Module): + + def __init__(self, rngs: nnx.Rngs, dim: int, eps: float, elementwise_affine: bool): + self.layer_norm = nnx.LayerNorm( + rngs=rngs, + num_features=dim, + epsilon=eps, + use_bias=elementwise_affine, + use_scale=elementwise_affine, + param_dtype=jnp.float32, + dtype=jnp.float32, + ) + + def __call__(self, inputs: jax.Array) -> jax.Array: + origin_dtype = inputs.dtype + return self.layer_norm(inputs.astype(dtype=jnp.float32)).astype(dtype=origin_dtype) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 8325c3707..19244f723 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -20,7 +20,7 @@ import jax.numpy as jnp from flax import nnx from ...configuration_utils import ConfigMixin -from ..modeling_flax_utils import FlaxModelMixin +from ..modeling_flax_utils import FlaxModelMixin, get_activation from ... import common_types from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) @@ -28,15 +28,6 @@ CACHE_T = 2 -_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "mish": jax.nn.mish} - - -def get_activation(name: str): - func = _ACTIVATIONS.get(name) - if func is None: - raise ValueError(f"Unknown activation function: {name}") - return func - # Helper to ensure kernel_size, stride, padding are tuples of 3 integers def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: @@ -60,6 +51,10 @@ def __init__( stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, use_bias: bool = True, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") self.stride = _canonicalize_tuple(stride, 3, "stride") @@ -76,6 +71,12 @@ def __init__( # Store the amount of padding needed *before* the depth dimension for caching logic self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] + # Set sharding dynamically based on out_channels. + num_fsdp_axis_devices = mesh.device_ids.shape[1] + kernel_sharding = (None, None, None, None, None) + if out_channels % num_fsdp_axis_devices == 0: + kernel_sharding = (None, None, None, None, "conv_out") + self.conv = nnx.Conv( in_features=in_channels, out_features=out_channels, @@ -84,6 +85,10 @@ def __init__( use_bias=use_bias, padding="VALID", # Handle padding manually rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), kernel_sharding), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ) def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array: @@ -184,8 +189,23 @@ def __init__( rngs: nnx.Rngs, kernel_size: Union[int, Tuple[int, int, int]], stride: Union[int, Tuple[int, int, int]] = 1, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): - self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs) + self.conv = nnx.Conv( + dim, + dim, + kernel_size=kernel_size, + strides=stride, + use_bias=True, + rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, None)), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + ) def __call__(self, x): return self.conv(x) @@ -198,6 +218,10 @@ def __init__( dim: int, mode: str, rngs: nnx.Rngs, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.dim = dim self.mode = mode @@ -213,6 +237,10 @@ def __init__( padding="SAME", use_bias=True, rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ), ) elif mode == "upsample3d": @@ -225,6 +253,10 @@ def __init__( padding="SAME", use_bias=True, rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, ), ) self.time_conv = WanCausalConv3d( @@ -233,13 +265,44 @@ def __init__( out_channels=dim * 2, kernel_size=(3, 1, 1), padding=(1, 0, 0), + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) elif mode == "downsample2d": - self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2)) + self.resample = ZeroPaddedConv2D( + dim=dim, + rngs=rngs, + kernel_size=(3, 3), + stride=(2, 2), + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) elif mode == "downsample3d": - self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2)) + self.resample = ZeroPaddedConv2D( + dim=dim, + rngs=rngs, + kernel_size=(3, 3), + stride=(2, 2), + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) self.time_conv = WanCausalConv3d( - rngs=rngs, in_channels=dim, out_channels=dim, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + rngs=rngs, + in_channels=dim, + out_channels=dim, + kernel_size=(3, 1, 1), + stride=(2, 1, 1), + padding=(0, 0, 0), + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) else: self.resample = Identity() @@ -301,16 +364,49 @@ def __init__( rngs: nnx.Rngs, dropout: float = 0.0, non_linearity: str = "silu", + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.nonlinearity = get_activation(non_linearity) # layers self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False) - self.conv1 = WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=3, padding=1) + self.conv1 = WanCausalConv3d( + rngs=rngs, + in_channels=in_dim, + out_channels=out_dim, + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False) - self.conv2 = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=out_dim, kernel_size=3, padding=1) + self.conv2 = WanCausalConv3d( + rngs=rngs, + in_channels=out_dim, + out_channels=out_dim, + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) self.conv_shortcut = ( - WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=1) + WanCausalConv3d( + rngs=rngs, + in_channels=in_dim, + out_channels=out_dim, + kernel_size=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) if in_dim != out_dim else Identity() ) @@ -353,11 +449,37 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): class WanAttentionBlock(nnx.Module): - def __init__(self, dim: int, rngs: nnx.Rngs): + def __init__( + self, + dim: int, + rngs: nnx.Rngs, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): self.dim = dim self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False) - self.to_qkv = nnx.Conv(in_features=dim, out_features=dim * 3, kernel_size=(1, 1), rngs=rngs) - self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=(1, 1), rngs=rngs) + self.to_qkv = nnx.Conv( + in_features=dim, + out_features=dim * 3, + kernel_size=(1, 1), + rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + ) + self.proj = nnx.Conv( + in_features=dim, + out_features=dim, + kernel_size=(1, 1), + rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "conv_in", None)), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + ) def __call__(self, x: jax.Array): @@ -371,7 +493,6 @@ def __call__(self, x: jax.Array): # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3) qkv = jnp.transpose(qkv, (0, 1, 3, 2)) - # q, k, v = jnp.split(qkv, 3, axis=-1) q, k, v = jnp.split(qkv, 3, axis=-2) q = jnp.transpose(q, (0, 1, 3, 2)) k = jnp.transpose(k, (0, 1, 3, 2)) @@ -389,13 +510,50 @@ def __call__(self, x: jax.Array): class WanMidBlock(nnx.Module): - def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + def __init__( + self, + dim: int, + rngs: nnx.Rngs, + dropout: float = 0.0, + non_linearity: str = "silu", + num_layers: int = 1, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): self.dim = dim - resnets = [WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity)] + resnets = [ + WanResidualBlock( + in_dim=dim, + out_dim=dim, + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) + ] attentions = [] for _ in range(num_layers): - attentions.append(WanAttentionBlock(dim=dim, rngs=rngs)) - resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity)) + attentions.append( + WanAttentionBlock(dim=dim, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision) + ) + resnets.append( + WanResidualBlock( + in_dim=dim, + out_dim=dim, + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) + ) self.attentions = attentions self.resnets = resnets @@ -419,6 +577,10 @@ def __init__( dropout: float = 0.0, upsample_mode: Optional[str] = None, non_linearity: str = "silu", + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): # Create layers list resnets = [] @@ -426,7 +588,17 @@ def __init__( current_dim = in_dim for _ in range(num_res_blocks + 1): resnets.append( - WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs) + WanResidualBlock( + in_dim=current_dim, + out_dim=out_dim, + dropout=dropout, + non_linearity=non_linearity, + rngs=rngs, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) ) current_dim = out_dim self.resnets = resnets @@ -434,7 +606,17 @@ def __init__( # Add upsampling layer if needed. self.upsamplers = None if upsample_mode is not None: - self.upsamplers = [WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs)] + self.upsamplers = [ + WanResample( + dim=out_dim, + mode=upsample_mode, + rngs=rngs, + mesh=mesh, + weights_dtype=weights_dtype, + dtype=dtype, + precision=precision, + ) + ] def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): for resnet in self.resnets: @@ -464,6 +646,10 @@ def __init__( temperal_downsample=[True, True, False], dropout=0.0, non_linearity: str = "silu", + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.dim = dim self.z_dim = z_dim @@ -484,6 +670,10 @@ def __init__( out_channels=dims[0], kernel_size=3, padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) # downsample blocks @@ -491,15 +681,34 @@ def __init__( for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): # residual (+attention) blocks for _ in range(num_res_blocks): - self.down_blocks.append(WanResidualBlock(in_dim=in_dim, out_dim=out_dim, dropout=dropout, rngs=rngs)) + self.down_blocks.append( + WanResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + rngs=rngs, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) + ) if scale in attn_scales: - self.down_blocks.append(WanAttentionBlock(dim=out_dim, rngs=rngs)) + self.down_blocks.append( + WanAttentionBlock( + dim=out_dim, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision + ) + ) in_dim = out_dim # downsample block if i != len(dim_mult) - 1: mode = "downsample3d" if temperal_downsample[i] else "downsample2d" - self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs)) + self.down_blocks.append( + WanResample( + out_dim, mode=mode, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision + ) + ) scale /= 2.0 # middle_blocks @@ -509,11 +718,25 @@ def __init__( dropout=dropout, non_linearity=non_linearity, num_layers=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) # output blocks self.norm_out = WanRMS_norm(out_dim, channel_first=False, images=False, rngs=rngs) - self.conv_out = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=z_dim, kernel_size=3, padding=1) + self.conv_out = WanCausalConv3d( + rngs=rngs, + in_channels=out_dim, + out_channels=z_dim, + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: @@ -576,6 +799,10 @@ def __init__( temperal_upsample=[False, True, True], dropout=0.0, non_linearity: str = "silu", + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.dim = dim self.z_dim = z_dim @@ -591,10 +818,30 @@ def __init__( scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block - self.conv_in = WanCausalConv3d(rngs=rngs, in_channels=z_dim, out_channels=dims[0], kernel_size=3, padding=1) + self.conv_in = WanCausalConv3d( + rngs=rngs, + in_channels=z_dim, + out_channels=dims[0], + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) # middle_blocks - self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1) + self.mid_block = WanMidBlock( + dim=dims[0], + rngs=rngs, + dropout=dropout, + non_linearity=non_linearity, + num_layers=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) # upsample blocks self.up_blocks = [] @@ -616,6 +863,10 @@ def __init__( upsample_mode=upsample_mode, non_linearity=non_linearity, rngs=rngs, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) self.up_blocks.append(up_block) @@ -625,7 +876,17 @@ def __init__( # output blocks self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False) - self.conv_out = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=3, kernel_size=3, padding=1) + self.conv_out = WanCausalConv3d( + rngs=rngs, + in_channels=out_dim, + out_channels=3, + kernel_size=3, + padding=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): if feat_cache is not None: @@ -737,6 +998,10 @@ def __init__( 2.8251, 1.9160, ], + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, ): self.z_dim = z_dim self.temperal_downsample = temperal_downsample @@ -753,13 +1018,31 @@ def __init__( attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) - self.quant_conv = WanCausalConv3d(rngs=rngs, in_channels=z_dim * 2, out_channels=z_dim * 2, kernel_size=1) + self.quant_conv = WanCausalConv3d( + rngs=rngs, + in_channels=z_dim * 2, + out_channels=z_dim * 2, + kernel_size=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) + self.post_quant_conv = WanCausalConv3d( rngs=rngs, in_channels=z_dim, out_channels=z_dim, kernel_size=1, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) self.decoder = WanDecoder3d( @@ -771,6 +1054,10 @@ def __init__( attn_scales=attn_scales, temperal_upsample=self.temporal_upsample, dropout=dropout, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): @@ -826,12 +1113,17 @@ def _decode( # Ideally shouldn't need to do this however, can't find where the frame is going out of sync. # Most likely due to an incorrect reshaping in the decoder. fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :] - if len(fm1.shape) == 4: - fm1 = jnp.expand_dims(fm1, axis=0) - fm2 = jnp.expand_dims(fm2, axis=0) - fm3 = jnp.expand_dims(fm3, axis=0) - fm4 = jnp.expand_dims(fm4, axis=0) + # When batch_size is 0, expand batch dim for contatenation + # else, expand frame dim for concatenation so that batch dim stays intact. + axis = 0 + if fm1.shape[0] > 1: + axis = 1 + if len(fm1.shape) == 4: + fm1 = jnp.expand_dims(fm1, axis=axis) + fm2 = jnp.expand_dims(fm2, axis=axis) + fm3 = jnp.expand_dims(fm3, axis=axis) + fm4 = jnp.expand_dims(fm4, axis=axis) out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1) out = jnp.clip(out, min=-1.0, max=1.0) feat_cache.clear_cache() diff --git a/src/maxdiffusion/models/wan/transformers/__init__.py b/src/maxdiffusion/models/wan/transformers/__init__.py new file mode 100644 index 000000000..9ff757fc3 --- /dev/null +++ b/src/maxdiffusion/models/wan/transformers/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py new file mode 100644 index 000000000..a084447b6 --- /dev/null +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -0,0 +1,492 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Tuple, Optional, Dict, Union, Any +import math +import jax +import jax.numpy as jnp +from flax import nnx +import numpy as np +from .... import common_types +from ...modeling_flax_utils import FlaxModelMixin, get_activation +from ....configuration_utils import ConfigMixin, register_to_config +from ...embeddings_flax import ( + get_1d_rotary_pos_embed, + NNXFlaxTimesteps, + NNXTimestepEmbedding, + NNXPixArtAlphaTextProjection, +) +from ...normalization_flax import FP32LayerNorm +from ...attention_flax import FlaxWanAttention + +BlockSizes = common_types.BlockSizes + + +def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int): + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + freqs = [] + for dim in [t_dim, h_dim, w_dim]: + freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float64, use_real=False) + freqs.append(freq) + freqs = jnp.concatenate(freqs, axis=1) + # sizes = jnp.array([ + # attention_head_dim // 2 - 2 * (attention_head_dim // 6), + # attention_head_dim // 6, + # attention_head_dim // 6, + # ]) + # cumulative_sizes = jnp.cumsum(jnp.array(sizes)) + # split_indices = cumulative_sizes[:-1] + t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6) + hw_size = attention_head_dim // 6 + + dims = [t_size, hw_size, hw_size] + + # Calculate split indices as a static list of integers + cumulative_sizes = np.cumsum(dims) + split_indices = cumulative_sizes[:-1].tolist() + freqs_split = jnp.split(freqs, split_indices, axis=1) + return freqs_split + + +class WanRotaryPosEmbed(nnx.Module): + + def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0): + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + self.theta = theta + + def __call__(self, hidden_states: jax.Array) -> jax.Array: + _, num_frames, height, width, _ = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + freqs_split = get_frequencies(self.max_seq_len, self.theta, self.attention_head_dim) + + freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1) + freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1])) + + freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2) + freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1])) + + freqs_w = jnp.expand_dims(jnp.expand_dims(freqs_split[2][:ppw], axis=0), axis=1) + freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1])) + + freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1) + freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1)) + return freqs_final + + +class WanTimeTextImageEmbedding(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: Optional[int] = None, + pos_embed_seq_len: Optional[int] = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): + self.timesteps_proj = NNXFlaxTimesteps(dim=time_freq_dim, flip_sin_to_cos=True, freq_shift=0) + self.time_embedder = NNXTimestepEmbedding( + rngs=rngs, + in_channels=time_freq_dim, + time_embed_dim=dim, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) + self.act_fn = get_activation("silu") + self.time_proj = nnx.Linear( + rngs=rngs, + in_features=dim, + out_features=time_proj_dim, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "embed", + "mlp", + ), + ), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + ) + self.text_embedder = NNXPixArtAlphaTextProjection( + rngs=rngs, + in_features=text_embed_dim, + hidden_size=dim, + act_fn="gelu_tanh", + ) + + def __call__( + self, timestep: jax.Array, encoder_hidden_states: jax.Array, encoder_hidden_states_image: Optional[jax.Array] = None + ): + timestep = self.timesteps_proj(timestep) + temb = self.time_embedder(timestep) + + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + raise NotImplementedError("currently img2vid is not supported") + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +class ApproximateGELU(nnx.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + """ + + def __init__( + self, + rngs: nnx.Rngs, + dim_in: int, + dim_out: int, + bias: bool, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): + self.proj = nnx.Linear( + rngs=rngs, + in_features=dim_in, + out_features=dim_out, + use_bias=bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + ) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.proj(x) + return nnx.gelu(x) + + +class WanFeedForward(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim: int = None, + bias: bool = True, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + self.act_fn = None + if activation_fn == "gelu-approximate": + self.act_fn = ApproximateGELU( + rngs=rngs, dim_in=dim, dim_out=inner_dim, bias=bias, dtype=dtype, weights_dtype=weights_dtype, precision=precision + ) + else: + raise NotImplementedError(f"{activation_fn} is not implemented.") + + self.proj_out = nnx.Linear( + rngs=rngs, + in_features=inner_dim, + out_features=dim_out, + use_bias=bias, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "mlp", + "embed", + ), + ), + ) + + def __call__(self, hidden_states: jax.Array) -> jax.Array: + hidden_states = self.act_fn(hidden_states) + return self.proj_out(hidden_states) + + +class WanTransformerBlock(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + # In torch, this is none, so it can be ignored. + # added_kv_proj_dim: Optional[int] = None, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", + ): + + # 1. Self-attention + self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) + self.attn1 = FlaxWanAttention( + rngs=rngs, + query_dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + qk_norm=qk_norm, + eps=eps, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + attention_kernel=attention, + ) + + # 1. Cross-attention + self.attn2 = FlaxWanAttention( + rngs=rngs, + query_dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + qk_norm=qk_norm, + eps=eps, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + attention_kernel=attention, + ) + assert cross_attn_norm is True + self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) + + # 3. Feed-forward + self.ffn = WanFeedForward( + rngs=rngs, + dim=dim, + inner_dim=ffn_dim, + activation_fn="gelu-approximate", + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) + self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) + + key = rngs.params() + self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5) + + def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array): + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( + (self.scale_shift_table + temb), 6, axis=1 + ) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) + attn_output = self.attn1( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb + ) + hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states + ff_output * c_gate_msa).astype(hidden_states.dtype) + return hidden_states + + +class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): + + @register_to_config + def __init__( + self, + rngs: nnx.Rngs, + model_type="t2v", + patch_size: Tuple[int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: Optional[int] = None, + added_kv_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: Optional[int] = None, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", + ): + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Patch & position embedding + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nnx.Conv( + in_channels, + inner_dim, + rngs=rngs, + kernel_size=patch_size, + strides=patch_size, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + None, + None, + None, + None, + "conv_out", + ), + ), + ) + + # 2. Condition embeddings + # image_embedding_dim=1280 for I2V model + self.condition_embedder = WanTimeTextImageEmbedding( + rngs=rngs, + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, + ) + + # 3. Transformer blocks + blocks = [] + for _ in range(num_layers): + block = WanTransformerBlock( + rngs=rngs, + dim=inner_dim, + ffn_dim=ffn_dim, + num_heads=num_attention_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + attention=attention, + ) + blocks.append(block) + self.blocks = blocks + + self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) + self.proj_out = nnx.Linear( + rngs=rngs, + in_features=inner_dim, + out_features=out_channels * math.prod(patch_size), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)), + ) + key = rngs.params() + self.scale_shift_table = nnx.Param( + jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")), + ) + + def __call__( + self, + hidden_states: jax.Array, + timestep: jax.Array, + encoder_hidden_states: jax.Array, + is_uncond: jax.Array, # jnp.bool_ scalar + slg_mask: jax.Array, # jnp.bool_ array of shape (num_blocks,) + encoder_hidden_states_image: Optional[jax.Array] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[jax.Array, Dict[str, jax.Array]]: + batch_size, _, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) + rotary_emb = self.rope(hidden_states) + hidden_states = self.patch_embedding(hidden_states) + hidden_states = jax.lax.collapse(hidden_states, 1, -1) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image + ) + timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) + + if encoder_hidden_states_image is not None: + raise NotImplementedError("img2vid is not yet implemented.") + for block_idx, block in enumerate(self.blocks): + should_skip_block = slg_mask[block_idx] & is_uncond + hidden_states = jax.lax.cond( + should_skip_block, + lambda hs: hs, # If true, pass through original hidden_states (skip block) + lambda _: block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb), + hidden_states, + ) + shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) + + hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).astype(hidden_states.dtype) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6)) + hidden_states = jax.lax.collapse(hidden_states, 6, None) + hidden_states = jax.lax.collapse(hidden_states, 4, 6) + hidden_states = jax.lax.collapse(hidden_states, 2, 4) + return hidden_states diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 1a9948fdb..f84346735 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -1,9 +1,10 @@ +import json import jax import jax.numpy as jnp from maxdiffusion import max_logging from huggingface_hub import hf_hub_download from safetensors import safe_open -from flax.traverse_util import unflatten_dict +from flax.traverse_util import unflatten_dict, flatten_dict from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict) @@ -17,6 +18,69 @@ def _tuple_str_to_int(in_tuple): return tuple(out_list) +def rename_for_nnx(key): + new_key = key + if "norm_k" in key or "norm_q" in key: + new_key = key[:-1] + ("scale",) + return new_key + + +def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): + device = jax.devices(device)[0] + with jax.default_device(device): + if hf_download: + # download the index file for sharded models. + index_file_path = hf_hub_download( + pretrained_model_name_or_path, subfolder="transformer", filename="diffusion_pytorch_model.safetensors.index.json" + ) + # open the index file. + with open(index_file_path, "r") as f: + index_dict = json.load(f) + model_files = set() + for key in index_dict["weight_map"].keys(): + model_files.add(index_dict["weight_map"][key]) + + model_files = list(model_files) + tensors = {} + for model_file in model_files: + ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename=model_file) + # now get all the filenames for the model that need downloading + max_logging.log(f"Load and port Wan 2.1 transformer on {device}") + + if ckpt_shard_path is not None: + with safe_open(ckpt_shard_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = torch2jax(f.get_tensor(k)) + flax_state_dict = {} + cpu = jax.local_devices(backend="cpu")[0] + flattened_dict = flatten_dict(eval_shapes) + # turn all block numbers to strings just for matching weights. + # Later they will be turned back to ints. + random_flax_state_dict = {} + for key in flattened_dict: + string_tuple = tuple([str(item) for item in key]) + random_flax_state_dict[string_tuple] = flattened_dict[key] + del flattened_dict + for pt_key, tensor in tensors.items(): + renamed_pt_key = rename_key(pt_key) + renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") + renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn") + renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out") + renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") + renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") + pt_tuple_key = tuple(renamed_pt_key.split(".")) + + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) + flax_key = rename_for_nnx(flax_key) + flax_key = _tuple_str_to_int(flax_key) + flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) + del tensors + jax.clear_caches() + return flax_state_dict + + def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] with jax.default_device(device): diff --git a/src/maxdiffusion/pipelines/wan/__init__.py b/src/maxdiffusion/pipelines/wan/__init__.py new file mode 100644 index 000000000..83a537f82 --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/__init__.py @@ -0,0 +1,17 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from .wan_pipeline import WanPipeline diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py new file mode 100644 index 000000000..8d9a2986b --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -0,0 +1,502 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union, Optional +from functools import partial +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P +import flax +import flax.linen as nn +from flax import nnx +from flax.linen import partitioning as nn_partitioning +from ...pyconfig import HyperParameters +from ... import max_logging +from ... import max_utils +from ...max_utils import get_flash_block_sizes, get_precision +from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae +from ...models.wan.transformers.transformer_wan import WanModel +from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache +from maxdiffusion.video_processor import VideoProcessor +from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState +from transformers import AutoTokenizer, UMT5EncoderModel +import ftfy +import html +import re +import torch + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState: + vs.sharding_rules = logical_axis_rules + return vs + + +# For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. +def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): + + def create_model(rngs: nnx.Rngs, wan_config: dict): + wan_transformer = WanModel(**wan_config, rngs=rngs) + return wan_transformer + + # 1. Load config. + wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer") + wan_config["mesh"] = mesh + wan_config["dtype"] = config.activations_dtype + wan_config["weights_dtype"] = config.weights_dtype + wan_config["attention"] = config.attention + wan_config["precision"] = get_precision(config) + wan_config["flash_block_sizes"] = get_flash_block_sizes(config) + + # 2. eval_shape - will not use flops or create weights on device + # thus not using HBM memory. + p_model_factory = partial(create_model, wan_config=wan_config) + wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs) + graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) + + # 3. retrieve the state shardings, mapping logical names to mesh axis names. + logical_state_spec = nnx.get_partition_spec(state) + logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) + logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) + params = state.to_pure_dict() + state = dict(nnx.to_flat_state(state)) + + # 4. Load pretrained weights and move them to device using the state shardings from (3) above. + # This helps with loading sharded weights directly into the accelerators without fist copying them + # all to one device and then distributing them, thus using low HBM memory. + params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu") + params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) + for path, val in flax.traverse_util.flatten_dict(params).items(): + sharding = logical_state_sharding[path].value + state[path].value = jax.device_put(val, sharding) + state = nnx.from_flat_state(state) + + wan_transformer = nnx.merge(graphdef, state, rest_of_state) + return wan_transformer + + +@nnx.jit(static_argnums=(1,), donate_argnums=(0,)) +def create_sharded_logical_model(model, logical_axis_rules): + graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...) + p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=logical_axis_rules) + state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) + pspecs = nnx.get_partition_spec(state) + sharded_state = jax.lax.with_sharding_constraint(state, pspecs) + model = nnx.merge(graphdef, sharded_state, rest_of_state) + return model + + +class WanPipeline: + r""" + Pipeline for text-to-video generation using Wan. + + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`WanModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`FlaxUniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: WanModel, + vae: AutoencoderKLWan, + vae_cache: AutoencoderKLWanCache, + scheduler: FlaxUniPCMultistepScheduler, + scheduler_state: UniPCMultistepSchedulerState, + devices_array: np.array, + mesh: Mesh, + config: HyperParameters, + ): + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.transformer = transformer + self.vae = vae + self.vae_cache = vae_cache + self.scheduler = scheduler + self.scheduler_state = scheduler_state + self.devices_array = devices_array + self.mesh = mesh + self.config = config + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + self.p_run_inference = None + + @classmethod + def load_text_encoder(cls, config: HyperParameters): + text_encoder = UMT5EncoderModel.from_pretrained( + config.pretrained_model_name_or_path, + subfolder="text_encoder", + ) + return text_encoder + + @classmethod + def load_tokenizer(cls, config: HyperParameters): + tokenizer = AutoTokenizer.from_pretrained( + config.pretrained_model_name_or_path, + subfolder="tokenizer", + ) + return tokenizer + + @classmethod + def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): + wan_vae = AutoencoderKLWan.from_config( + config.pretrained_model_name_or_path, + subfolder="vae", + rngs=rngs, + mesh=mesh, + dtype=config.activations_dtype, + weights_dtype=config.weights_dtype, + ) + vae_cache = AutoencoderKLWanCache(wan_vae) + + graphdef, state = nnx.split(wan_vae, nnx.Param) + params = state.to_pure_dict() + # This replaces random params with the model. + params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") + params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) + params = jax.device_put(params, PositionalSharding(devices_array).replicate()) + wan_vae = nnx.merge(graphdef, params) + p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules) + # Shard + with mesh: + wan_vae = p_create_sharded_logical_model(model=wan_vae) + return wan_vae, vae_cache + + @classmethod + def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): + with mesh: + wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + return wan_transformer + + @classmethod + def load_scheduler(cls, config): + scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( + config.pretrained_model_name_or_path, + subfolder="scheduler", + flow_shift=config.flow_shift, # 5.0 for 720p, 3.0 for 480p + ) + return scheduler, scheduler_state + + @classmethod + def from_pretrained(cls, config: HyperParameters, vae_only=False): + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + rng = jax.random.key(config.seed) + rngs = nnx.Rngs(rng) + transformer = None + tokenizer = None + scheduler = None + scheduler_state = None + text_encoder = None + if not vae_only: + with mesh: + transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + + text_encoder = cls.load_text_encoder(config=config) + tokenizer = cls.load_tokenizer(config=config) + + scheduler, scheduler_state = cls.load_scheduler(config=config) + + with mesh: + wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + + return WanPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + vae=wan_vae, + vae_cache=vae_cache, + scheduler=scheduler, + scheduler_state=scheduler_state, + devices_array=devices_array, + mesh=mesh, + config=config, + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype) + + if negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + batch_size: int, + vae_scale_factor_temporal: int, + vae_scale_factor_spatial: int, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_channels_latents: int = 16, + ): + rng = jax.random.key(self.config.seed) + num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // vae_scale_factor_spatial, + int(width) // vae_scale_factor_spatial, + ) + latents = jax.random.normal(rng, shape=shape, dtype=self.config.weights_dtype) + + return latents + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, + slg_layers: List[int] = None, + slg_start: float = 0.0, + slg_end: float = 1.0, + ): + if not vae_only: + if num_frames % self.vae_scale_factor_temporal != 1: + max_logging.log( + f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + prompt = [prompt] + + batch_size = len(prompt) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + max_sequence_length=max_sequence_length, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + num_channel_latents = self.transformer.config.in_channels + if latents is None: + latents = self.prepare_latents( + batch_size=batch_size, + vae_scale_factor_temporal=self.vae_scale_factor_temporal, + vae_scale_factor_spatial=self.vae_scale_factor_spatial, + height=height, + width=width, + num_frames=num_frames, + num_channels_latents=num_channel_latents, + ) + + data_sharding = PositionalSharding(self.devices_array).replicate() + if len(prompt) % jax.device_count() == 0: + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + + latents = jax.device_put(latents, data_sharding) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) + + scheduler_state = self.scheduler.set_timesteps( + self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape + ) + + graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) + + p_run_inference = partial( + run_inference, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state, + slg_layers=slg_layers, + slg_start=slg_start, + slg_end=slg_end, + num_transformer_layers=self.transformer.config.num_layers, + ) + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + latents = p_run_inference( + graphdef=graphdef, + sharded_state=state, + rest_of_state=rest_of_state, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) + latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) + latents = latents / latents_std + latents_mean + latents = latents.astype(self.config.weights_dtype) + + video = self.vae.decode(latents, self.vae_cache)[0] + + video = jnp.transpose(video, (0, 4, 1, 2, 3)) + video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) + video = self.video_processor.postprocess_video(video, output_type="np") + return video + + +@jax.jit +def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds, is_uncond, slg_mask): + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + return wan_transformer( + hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, is_uncond=is_uncond, slg_mask=slg_mask + ) + + +def run_inference( + graphdef, + sharded_state, + rest_of_state, + latents: jnp.array, + prompt_embeds: jnp.array, + negative_prompt_embeds: jnp.array, + guidance_scale: float, + num_inference_steps: int, + scheduler: FlaxUniPCMultistepScheduler, + num_transformer_layers: int, + scheduler_state, + slg_layers: List[int] = None, + slg_start: float = 0.0, + slg_end: float = 1.0, +): + do_classifier_free_guidance = guidance_scale > 1.0 + for step in range(num_inference_steps): + slg_mask = jnp.zeros(num_transformer_layers, dtype=jnp.bool_) + if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps): + slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True) + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents.shape[0]) + + noise_pred = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_embeds, + is_uncond=jnp.array(False, dtype=jnp.bool_), + slg_mask=slg_mask, + ) + + if do_classifier_free_guidance: + noise_uncond = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + negative_prompt_embeds, + is_uncond=jnp.array(True, dtype=jnp.bool_), + slg_mask=slg_mask, + ) + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents diff --git a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py index 85e116137..b04a142de 100644 --- a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py @@ -33,875 +33,770 @@ @flax.struct.dataclass class UniPCMultistepSchedulerState: - """ - Data class to hold the mutable state of the FlaxUniPCMultistepScheduler. - """ - - common: CommonSchedulerState - - # Core schedule parameters (derived from CommonSchedulerState in create_state) - sigmas: jnp.ndarray - alpha_t: jnp.ndarray - sigma_t: jnp.ndarray - lambda_t: jnp.ndarray - init_noise_sigma: float - - # History buffers for multi-step solver - # `model_outputs` stores previous converted model outputs (e.g., predicted x0 or epsilon) - timesteps: jnp.ndarray = None - model_outputs: jnp.ndarray = None - timestep_list: jnp.ndarray = ( - None # Stores corresponding timesteps for `model_outputs` + """ + Data class to hold the mutable state of the FlaxUniPCMultistepScheduler. + """ + + common: CommonSchedulerState + + # Core schedule parameters (derived from CommonSchedulerState in create_state) + sigmas: jnp.ndarray + alpha_t: jnp.ndarray + sigma_t: jnp.ndarray + lambda_t: jnp.ndarray + init_noise_sigma: float + + # History buffers for multi-step solver + # `model_outputs` stores previous converted model outputs (e.g., predicted x0 or epsilon) + timesteps: jnp.ndarray = None + model_outputs: jnp.ndarray = None + timestep_list: jnp.ndarray = None # Stores corresponding timesteps for `model_outputs` + + # State variables for tracking progress and solver order + lower_order_nums: int = 0 + last_sample: Optional[jnp.ndarray] = None # Sample from the previous predictor step + step_index: Optional[int] = None + begin_index: Optional[int] = None # Used for img2img/inpaing + this_order: int = 0 # Current effective order of the UniPC solver for this step + + @classmethod + def create( + cls, + common_state: CommonSchedulerState, + alpha_t: jnp.ndarray, + sigma_t: jnp.ndarray, + lambda_t: jnp.ndarray, + sigmas: jnp.ndarray, + init_noise_sigma: jnp.ndarray, + ): + return cls( + common=common_state, + alpha_t=alpha_t, + sigma_t=sigma_t, + lambda_t=lambda_t, + sigmas=sigmas, + init_noise_sigma=init_noise_sigma, + lower_order_nums=0, + last_sample=None, + step_index=None, + begin_index=None, + this_order=0, ) - # State variables for tracking progress and solver order - lower_order_nums: int = 0 - last_sample: Optional[jnp.ndarray] = None # Sample from the previous predictor step - step_index: Optional[int] = None - begin_index: Optional[int] = None # Used for img2img/inpaing - this_order: int = 0 # Current effective order of the UniPC solver for this step - - @classmethod - def create( - cls, - common_state: CommonSchedulerState, - alpha_t: jnp.ndarray, - sigma_t: jnp.ndarray, - lambda_t: jnp.ndarray, - sigmas: jnp.ndarray, - init_noise_sigma: jnp.ndarray, - ): - return cls( - common=common_state, - alpha_t=alpha_t, - sigma_t=sigma_t, - lambda_t=lambda_t, - sigmas=sigmas, - init_noise_sigma=init_noise_sigma, - lower_order_nums=0, - last_sample=None, - step_index=None, - begin_index=None, - this_order=0, - ) - @flax.struct.dataclass(frozen=False) class FlaxUniPCMultistepSchedulerOutput(FlaxSchedulerOutput): - state: UniPCMultistepSchedulerState + state: UniPCMultistepSchedulerState class FlaxUniPCMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + `FlaxUniPCMultistepScheduler` is a JAX/Flax training-free framework designed for the fast sampling of diffusion models. + It implements the UniPC (Unified Predictor-Corrector) algorithm for efficient diffusion model sampling. + """ + + dtype: jnp.dtype + + @property + def has_state(self) -> bool: + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[jnp.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: Optional[FlaxSchedulerMixin] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + use_flow_sigmas: Optional[bool] = False, + flow_shift: Optional[float] = 1.0, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", + rescale_zero_terminal_snr: bool = False, + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + # Validation checks from original __init__ + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if ( + sum( + [ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ] + ) + > 1 + ): + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + if self.config.solver_type not in ["bh1", "bh2"]: + raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}") + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> UniPCMultistepSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + if self.config.get("rescale_zero_terminal_snr", False): + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + alphas_cumprod = common.alphas_cumprod + alphas_cumprod = alphas_cumprod.at[-1].set(2**-24) + common = common.replace(alphas_cumprod=alphas_cumprod) + + # Currently we only support VP-type noise schedule + alpha_t = jnp.sqrt(common.alphas_cumprod) + sigma_t = jnp.sqrt(1 - common.alphas_cumprod) + lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t) + sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + if self.config.solver_type not in ["bh1", "bh2"]: + if self.config.solver_type in ["midpoint", "heun", "logrho"]: + self.config.solver_type = "bh2" + else: + raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}") + + return UniPCMultistepSchedulerState.create( + common_state=common, + alpha_t=alpha_t, + sigma_t=sigma_t, + lambda_t=lambda_t, + sigmas=sigmas, + init_noise_sigma=init_noise_sigma, + ) + + def set_begin_index(self, state: UniPCMultistepSchedulerState, begin_index: int = 0) -> UniPCMultistepSchedulerState: """ - `FlaxUniPCMultistepScheduler` is a JAX/Flax training-free framework designed for the fast sampling of diffusion models. - It implements the UniPC (Unified Predictor-Corrector) algorithm for efficient diffusion model sampling. + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. """ - - dtype: jnp.dtype - - @property - def has_state(self) -> bool: - return True - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[jnp.ndarray, List[float]]] = None, - solver_order: int = 2, - prediction_type: str = "epsilon", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - predict_x0: bool = True, - solver_type: str = "bh2", - lower_order_final: bool = True, - disable_corrector: List[int] = [], - solver_p: Optional[FlaxSchedulerMixin] = None, - use_karras_sigmas: Optional[bool] = False, - use_exponential_sigmas: Optional[bool] = False, - use_beta_sigmas: Optional[bool] = False, - use_flow_sigmas: Optional[bool] = False, - flow_shift: Optional[float] = 1.0, - timestep_spacing: str = "linspace", - steps_offset: int = 0, - final_sigmas_type: Optional[str] = "zero", - rescale_zero_terminal_snr: bool = False, - dtype: jnp.dtype = jnp.float32, - ): - self.dtype = dtype - - # Validation checks from original __init__ - if self.config.use_beta_sigmas and not is_scipy_available(): - raise ImportError( - "Make sure to install scipy if you want to use beta sigmas." - ) - if ( - sum( - [ - self.config.use_beta_sigmas, - self.config.use_exponential_sigmas, - self.config.use_karras_sigmas, - ] - ) - > 1 - ): - raise ValueError( - "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." - ) - if self.config.solver_type not in ["bh1", "bh2"]: - raise NotImplementedError( - f"{self.config.solver_type} is not implemented for {self.__class__}" - ) - - def create_state( - self, common: Optional[CommonSchedulerState] = None - ) -> UniPCMultistepSchedulerState: - if common is None: - common = CommonSchedulerState.create(self) - - if self.config.get("rescale_zero_terminal_snr", False): - # Close to 0 without being 0 so first sigma is not inf - # FP16 smallest positive subnormal works well here - alphas_cumprod = common.alphas_cumprod - alphas_cumprod = alphas_cumprod.at[-1].set(2**-24) - common = common.replace(alphas_cumprod=alphas_cumprod) - - # Currently we only support VP-type noise schedule - alpha_t = jnp.sqrt(common.alphas_cumprod) - sigma_t = jnp.sqrt(1 - common.alphas_cumprod) - lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t) - sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 - - # standard deviation of the initial noise distribution - init_noise_sigma = jnp.array(1.0, dtype=self.dtype) - - if self.config.solver_type not in ["bh1", "bh2"]: - if self.config.solver_type in ["midpoint", "heun", "logrho"]: - self.config.solver_type = "bh2" - else: - raise NotImplementedError( - f"{self.config.solver_type} is not implemented for {self.__class__}" - ) - - return UniPCMultistepSchedulerState.create( - common_state=common, - alpha_t=alpha_t, - sigma_t=sigma_t, - lambda_t=lambda_t, - sigmas=sigmas, - init_noise_sigma=init_noise_sigma, + return state.replace(begin_index=begin_index) + + def set_timesteps( + self, + state: UniPCMultistepSchedulerState, + num_inference_steps: int, + shape: Tuple, + ) -> UniPCMultistepSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + """ + #### Copied from scheduling_dmpsolver_multistep_flax + last_timestep = self.config.num_train_timesteps + if self.config.timestep_spacing == "linspace": + timesteps = jnp.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].astype(jnp.int32) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (jnp.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(jnp.int32) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = jnp.arange(last_timestep, 0, -step_ratio).round().copy().astype(jnp.int32) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + # initial running values + sigmas = state.sigmas + + # TODO + # # Apply Karras/Exponential/Beta/Flow Sigmas if configured + if self.config.use_karras_sigmas: + # sigmas = _convert_to_karras_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + raise NotImplementedError("`use_karras_sigmas` is not implemented in JAX version yet.") + elif self.config.use_exponential_sigmas: + # sigmas = _convert_to_exponential_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + raise NotImplementedError("`use_exponential_sigmas` is not implemented in JAX version yet.") + elif self.config.use_beta_sigmas: + # sigmas = _convert_to_beta_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + raise NotImplementedError("`use_beta_sigmas` is not implemented in JAX version yet.") + if self.config.use_flow_sigmas: + alphas = jnp.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) + sigmas = 1.0 - alphas + sigmas = jnp.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() + timesteps = (sigmas * self.config.num_train_timesteps).copy().astype(jnp.int64) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" ) - - def set_begin_index( - self, state: UniPCMultistepSchedulerState, begin_index: int = 0 - ) -> UniPCMultistepSchedulerState: - """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. - """ - return state.replace(begin_index=begin_index) - - def set_timesteps( - self, - state: UniPCMultistepSchedulerState, - num_inference_steps: int, - shape: Tuple, - ) -> UniPCMultistepSchedulerState: - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - """ - #### Copied from scheduling_dmpsolver_multistep_flax - last_timestep = self.config.num_train_timesteps - if self.config.timestep_spacing == "linspace": - timesteps = ( - jnp.linspace(0, last_timestep - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .astype(jnp.int32) - ) - elif self.config.timestep_spacing == "leading": - step_ratio = last_timestep // (num_inference_steps + 1) - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - (jnp.arange(0, num_inference_steps + 1) * step_ratio) - .round()[::-1][:-1] - .copy() - .astype(jnp.int32) - ) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - jnp.arange(last_timestep, 0, -step_ratio) - .round() - .copy() - .astype(jnp.int32) - ) - timesteps -= 1 - else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." - ) - - # initial running values - sigmas = state.sigmas - - # TODO - # # Apply Karras/Exponential/Beta/Flow Sigmas if configured - if self.config.use_karras_sigmas: - # sigmas = _convert_to_karras_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) - raise NotImplementedError( - "`use_karras_sigmas` is not implemented in JAX version yet." - ) - elif self.config.use_exponential_sigmas: - # sigmas = _convert_to_exponential_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) - raise NotImplementedError( - "`use_exponential_sigmas` is not implemented in JAX version yet." - ) - elif self.config.use_beta_sigmas: - # sigmas = _convert_to_beta_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) - raise NotImplementedError( - "`use_beta_sigmas` is not implemented in JAX version yet." - ) - if self.config.use_flow_sigmas: - alphas = jnp.linspace( - 1, 1 / self.config.num_train_timesteps, num_inference_steps + 1 - ) - sigmas = 1.0 - alphas - sigmas = jnp.flip( - self.config.flow_shift - * sigmas - / (1 + (self.config.flow_shift - 1) * sigmas) - )[:-1].copy() - timesteps = ( - (sigmas * self.config.num_train_timesteps).copy().astype(jnp.int64) - ) - if self.config.final_sigmas_type == "sigma_min": - sigma_last = sigmas[-1] - elif self.config.final_sigmas_type == "zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" - ) - sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype( - jnp.float32 - ) - else: # Default case if none of the specialized sigmas are used - sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas) - if self.config.final_sigmas_type == "sigma_min": - sigma_last = ( - (1 - state.common.alphas_cumprod[0]) - / state.common.alphas_cumprod[0] - ) ** 0.5 - elif self.config.final_sigmas_type == "zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" - ) - sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype( - jnp.float32 - ) - - model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) - timestep_list = jnp.zeros( - (self.config.solver_order,), dtype=jnp.int32 # Timesteps are integers + sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype(jnp.float32) + else: # Default case if none of the specialized sigmas are used + sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - state.common.alphas_cumprod[0]) / state.common.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" ) - # Update the state with the new schedule and re-initialized history - return state.replace( - timesteps=timesteps, - sigmas=sigmas, - model_outputs=model_outputs, - timestep_list=timestep_list, - lower_order_nums=0, # Reset counters for a new inference run - step_index=None, - begin_index=None, - last_sample=None, - this_order=0, + sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype(jnp.float32) + + model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) + timestep_list = jnp.zeros((self.config.solver_order,), dtype=jnp.int32) # Timesteps are integers + # Update the state with the new schedule and re-initialized history + return state.replace( + timesteps=timesteps, + sigmas=sigmas, + model_outputs=model_outputs, + timestep_list=timestep_list, + lower_order_nums=0, # Reset counters for a new inference run + step_index=None, + begin_index=None, + last_sample=None, + this_order=0, + ) + + def convert_model_output( + self, + state: UniPCMultistepSchedulerState, + model_output: jnp.ndarray, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + Converts the model output based on the prediction type and current state. + """ + sigma = state.sigmas[state.step_index] # Current sigma + + # Ensure sigma is a JAX array for _sigma_to_alpha_sigma_t + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.config.predict_x0: + if self.config.prediction_type == "epsilon": + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + x0_pred = alpha_t * sample - sigma_t * model_output + elif self.config.prediction_type == "flow_prediction": + # Original code has `sigma_t = self.sigmas[self.step_index]`. + # This implies current sigma `sigma` is used as sigma_t for flow. + x0_pred = sample - sigma * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler." ) - def convert_model_output( - self, - state: UniPCMultistepSchedulerState, - model_output: jnp.ndarray, - sample: jnp.ndarray, - ) -> jnp.ndarray: - """ - Converts the model output based on the prediction type and current state. - """ - sigma = state.sigmas[state.step_index] # Current sigma - - # Ensure sigma is a JAX array for _sigma_to_alpha_sigma_t - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - - if self.config.predict_x0: - if self.config.prediction_type == "epsilon": - x0_pred = (sample - sigma_t * model_output) / alpha_t - elif self.config.prediction_type == "sample": - x0_pred = model_output - elif self.config.prediction_type == "v_prediction": - x0_pred = alpha_t * sample - sigma_t * model_output - elif self.config.prediction_type == "flow_prediction": - # Original code has `sigma_t = self.sigmas[self.step_index]`. - # This implies current sigma `sigma` is used as sigma_t for flow. - x0_pred = sample - sigma * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " - "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler." - ) - - if self.config.thresholding: - raise NotImplementedError("Dynamic thresholding isn't implemented.") - # x0_pred = self._threshold_sample(x0_pred) - return x0_pred - else: # self.config.predict_x0 is False - if self.config.prediction_type == "epsilon": - return model_output - elif self.config.prediction_type == "sample": - epsilon = (sample - alpha_t * model_output) / sigma_t - return epsilon - elif self.config.prediction_type == "v_prediction": - epsilon = alpha_t * model_output + sigma_t * sample - return epsilon - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the UniPCMultistepScheduler." - ) - - def multistep_uni_p_bh_update( - self, - state: UniPCMultistepSchedulerState, - model_output: jnp.ndarray, - sample: jnp.ndarray, - order: int, - ) -> jnp.ndarray: - """ - One step for the UniP (B(h) version) - the Predictor. - """ - if self.config.solver_p: - raise NotImplementedError( - "Nested `solver_p` is not implemented in JAX version yet." - ) - - m0 = state.model_outputs[ - self.config.solver_order - 1 - ] # Most recent stored converted model output - x = sample - - sigma_t_val, sigma_s0_val = ( - state.sigmas[state.step_index + 1], - state.sigmas[state.step_index], + if self.config.thresholding: + raise NotImplementedError("Dynamic thresholding isn't implemented.") + # x0_pred = self._threshold_sample(x0_pred) + return x0_pred + else: # self.config.predict_x0 is False + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." ) - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val) + def multistep_uni_p_bh_update( + self, + state: UniPCMultistepSchedulerState, + model_output: jnp.ndarray, + sample: jnp.ndarray, + order: int, + ) -> jnp.ndarray: + """ + One step for the UniP (B(h) version) - the Predictor. + """ + if self.config.solver_p: + raise NotImplementedError("Nested `solver_p` is not implemented in JAX version yet.") - lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10) - lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10) + m0 = state.model_outputs[self.config.solver_order - 1] # Most recent stored converted model output + x = sample - h = lambda_t - lambda_s0 + sigma_t_val, sigma_s0_val = ( + state.sigmas[state.step_index + 1], + state.sigmas[state.step_index], + ) - def rk_d1_loop_body(i, carry): - # Loop from i = 0 to order-2 - rks, D1s = carry - history_idx = self.config.solver_order - 2 - i - mi = state.model_outputs[history_idx] - si_val = state.timestep_list[history_idx] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val) + + lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10) + lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10) + + h = lambda_t - lambda_s0 + + def rk_d1_loop_body(i, carry): + # Loop from i = 0 to order-2 + rks, D1s = carry + history_idx = self.config.solver_order - 2 - i + mi = state.model_outputs[history_idx] + si_val = state.timestep_list[history_idx] + + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(state.sigmas[self.index_for_timestep(state, si_val)]) + lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10) + + rk = (lambda_si - lambda_s0) / h + Di = (mi - m0) / rk + + rks = rks.at[i].set(rk) + D1s = D1s.at[i].set(Di) + return rks, D1s + + rks_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + D1s_init = jnp.zeros((self.config.solver_order - 1, *m0.shape), dtype=m0.dtype) + if self.config.solver_order == 1: + # Dummy D1s array. It will not be used if order == 1 + D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype) + rks, D1s = jax.lax.fori_loop(0, order - 1, rk_d1_loop_body, (rks_init, D1s_init)) + rks = rks.at[order - 1].set(1.0) + + hh = -h if self.config.predict_x0 else h + h_phi_1 = jnp.expm1(hh) + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = jnp.expm1(hh) + else: + raise NotImplementedError() + + def rb_loop_body(i, carry): + R, b, current_h_phi_k, factorial_val = carry + R = R.at[i].set(jnp.power(rks, i)) + b = b.at[i].set(current_h_phi_k * factorial_val / B_h) + + def update_fn(vals): + _h_phi_k, _fac = vals + next_fac = _fac * (i + 2) + next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac + return next_h_phi_k, next_fac + + current_h_phi_k, factorial_val = jax.lax.cond( + i < order - 1, + update_fn, + lambda vals: vals, + (current_h_phi_k, factorial_val), + ) + return R, b, current_h_phi_k, factorial_val + + R_init = jnp.zeros((self.config.solver_order, self.config.solver_order), dtype=h.dtype) + b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + init_h_phi_k = h_phi_1 / hh - 1.0 + init_factorial = 1.0 + R, b, _, _ = jax.lax.fori_loop(0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial)) + + if len(D1s) > 0: + D1s = jnp.stack(D1s, axis=1) # Resulting shape (B, K, C, H, W) + + def solve_for_rhos_p(R_mat, b_vec, current_order): + # Create a mask for the top-left (current_order - 1) x (current_order - 1) sub-matrix + mask_size = self.config.solver_order - 1 + mask = jnp.arange(mask_size) < (current_order - 1) + mask_2d = mask[:, None] & mask[None, :] + + # Pad R with identity and b with zeros for a safe solve + R_safe = jnp.where( + mask_2d, + R_mat[:mask_size, :mask_size], + jnp.eye(mask_size, dtype=R_mat.dtype), + ) + b_safe = jnp.where(mask, b_vec[:mask_size], 0.0) + + # Solve the system and mask the result + solved_rhos = jnp.linalg.solve(R_safe, b_safe) + return jnp.where(mask, solved_rhos, 0.0) + + # Handle the special case for order == 2 + if self.config.solver_order == 1: + # Dummy rhos_p_padded for tracing. + rhos_p_order2 = jnp.zeros(1, dtype=x.dtype) + else: + rhos_p_order2 = jnp.zeros(self.config.solver_order - 1, dtype=x.dtype).at[0].set(0.5) + + # Get the result for the general case + rhos_p_general = solve_for_rhos_p(R, b, order) + + # Select the appropriate result based on the order + rhos_p = jnp.where(order == 2, rhos_p_order2, rhos_p_general) + + pred_res = jax.lax.cond( + order > 1, + lambda _: jnp.einsum("k,bkc...->bc...", rhos_p, D1s).astype(x.dtype), + # False branch: return a zero tensor with the correct shape. + lambda _: jnp.zeros_like(x), + operand=None, + ) - alpha_si, sigma_si = self._sigma_to_alpha_sigma_t( - state.sigmas[self.index_for_timestep(state, si_val)] - ) - lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10) + if self.config.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: # Predict epsilon + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + x_t = x_t_ - sigma_t * B_h * pred_res + + return x_t.astype(x.dtype) + + def multistep_uni_c_bh_update( + self, + state: UniPCMultistepSchedulerState, + this_model_output: jnp.ndarray, + last_sample: jnp.ndarray, # Sample after predictor `x_{t-1}` + this_sample: jnp.ndarray, # Sample before corrector `x_t` (after predictor step) + order: int, + ) -> jnp.ndarray: + """ + One step for the UniC (B(h) version) - the Corrector. + """ + model_output_list = state.model_outputs + m0 = model_output_list[self.config.solver_order - 1] # Most recent model output from history - rk = (lambda_si - lambda_s0) / h - Di = (mi - m0) / rk + if last_sample is not None: + x = last_sample + else: + # If it's None, create dummy data. This is for the tracing purpose + x = jnp.zeros_like(this_sample) - rks = rks.at[i].set(rk) - D1s = D1s.at[i].set(Di) - return rks, D1s + x_t = this_sample - rks_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) - D1s_init = jnp.zeros((self.config.solver_order - 1, *m0.shape), dtype=m0.dtype) - if self.config.solver_order == 1: - # Dummy D1s array. It will not be used if order == 1 - D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype) - rks, D1s = jax.lax.fori_loop( - 0, order - 1, rk_d1_loop_body, (rks_init, D1s_init) - ) - rks = rks.at[order - 1].set(1.0) - - hh = -h if self.config.predict_x0 else h - h_phi_1 = jnp.expm1(hh) - - if self.config.solver_type == "bh1": - B_h = hh - elif self.config.solver_type == "bh2": - B_h = jnp.expm1(hh) - else: - raise NotImplementedError() - - def rb_loop_body(i, carry): - R, b, current_h_phi_k, factorial_val = carry - R = R.at[i].set(jnp.power(rks, i)) - b = b.at[i].set(current_h_phi_k * factorial_val / B_h) - - def update_fn(vals): - _h_phi_k, _fac = vals - next_fac = _fac * (i + 2) - next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac - return next_h_phi_k, next_fac - - current_h_phi_k, factorial_val = jax.lax.cond( - i < order - 1, - update_fn, - lambda vals: vals, - (current_h_phi_k, factorial_val), - ) - return R, b, current_h_phi_k, factorial_val - - R_init = jnp.zeros( - (self.config.solver_order, self.config.solver_order), dtype=h.dtype - ) - b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) - init_h_phi_k = h_phi_1 / hh - 1.0 - init_factorial = 1.0 - R, b, _, _ = jax.lax.fori_loop( - 0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial) - ) + model_t = this_model_output - if len(D1s) > 0: - D1s = jnp.stack(D1s, axis=1) # Resulting shape (B, K, C, H, W) - - def solve_for_rhos_p(R_mat, b_vec, current_order): - # Create a mask for the top-left (current_order - 1) x (current_order - 1) sub-matrix - mask_size = self.config.solver_order - 1 - mask = jnp.arange(mask_size) < (current_order - 1) - mask_2d = mask[:, None] & mask[None, :] - - # Pad R with identity and b with zeros for a safe solve - R_safe = jnp.where( - mask_2d, - R_mat[:mask_size, :mask_size], - jnp.eye(mask_size, dtype=R_mat.dtype), - ) - b_safe = jnp.where(mask, b_vec[:mask_size], 0.0) - - # Solve the system and mask the result - solved_rhos = jnp.linalg.solve(R_safe, b_safe) - return jnp.where(mask, solved_rhos, 0.0) - - # Handle the special case for order == 2 - if self.config.solver_order == 1: - # Dummy rhos_p_padded for tracing. - rhos_p_order2 = jnp.zeros(1, dtype=x.dtype) - else: - rhos_p_order2 = ( - jnp.zeros(self.config.solver_order - 1, dtype=x.dtype).at[0].set(0.5) - ) - - # Get the result for the general case - rhos_p_general = solve_for_rhos_p(R, b, order) - - # Select the appropriate result based on the order - rhos_p = jnp.where(order == 2, rhos_p_order2, rhos_p_general) - - pred_res = jax.lax.cond( - order > 1, - lambda _: jnp.einsum("k,bkc...->bc...", rhos_p, D1s).astype(x.dtype), - # False branch: return a zero tensor with the correct shape. - lambda _: jnp.zeros_like(x), - operand=None, - ) + sigma_t_val = state.sigmas[state.step_index] + sigma_s0_val = state.sigmas[state.step_index - 1] - if self.config.predict_x0: - x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - x_t = x_t_ - alpha_t * B_h * pred_res - else: # Predict epsilon - x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - x_t = x_t_ - sigma_t * B_h * pred_res - - return x_t.astype(x.dtype) - - def multistep_uni_c_bh_update( - self, - state: UniPCMultistepSchedulerState, - this_model_output: jnp.ndarray, - last_sample: jnp.ndarray, # Sample after predictor `x_{t-1}` - this_sample: jnp.ndarray, # Sample before corrector `x_t` (after predictor step) - order: int, - ) -> jnp.ndarray: - """ - One step for the UniC (B(h) version) - the Corrector. - """ - model_output_list = state.model_outputs - m0 = model_output_list[ - self.config.solver_order - 1 - ] # Most recent model output from history - - if last_sample is not None: - x = last_sample - else: - # If it's None, create dummy data. This is for the tracing purpose - x = jnp.zeros_like(this_sample) - - x_t = this_sample - - model_t = this_model_output - - sigma_t_val = state.sigmas[state.step_index] - sigma_s0_val = state.sigmas[state.step_index - 1] - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val) - - lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10) - lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10) - - h = lambda_t - lambda_s0 - - def rk_d1_loop_body(i, carry): - # Loop from i = 0 to order-1. - rks, D1s = carry - - # Get history from state buffer - history_idx = self.config.solver_order - (i + 2) - mi = state.model_outputs[history_idx] - si_val = state.timestep_list[ - history_idx - ] # This is the actual timestep value - - alpha_si, sigma_si = self._sigma_to_alpha_sigma_t( - state.sigmas[self.index_for_timestep(state, si_val)] - ) - lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10) - - rk = (lambda_si - lambda_s0) / h - Di = (mi - m0) / rk - - # Update pre-allocated arrays - rks = rks.at[i].set(rk) - D1s = D1s.at[i].set(Di) - return rks, D1s - - # Pre-allocate arrays to max possible size - rks_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) - D1s_init = jnp.zeros((self.config.solver_order - 1, *m0.shape), dtype=m0.dtype) - if self.config.solver_order == 1: - # Dummy D1s array. It will not be used if order == 1. This is for tracing. - D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype) - - # Run the loop up to `order - 1` - rks, D1s = jax.lax.fori_loop( - 0, order - 1, rk_d1_loop_body, (rks_init, D1s_init) - ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val) - rks = rks.at[order - 1].set(1.0) - - hh = -h if self.config.predict_x0 else h - h_phi_1 = jnp.expm1(hh) - - if self.config.solver_type == "bh1": - B_h = hh - elif self.config.solver_type == "bh2": - B_h = jnp.expm1(hh) - else: - raise NotImplementedError() - - def rb_loop_body(i, carry): - # Loop from i = 0 to order-1 - R, b, current_h_phi_k, factorial_val = carry - - R = R.at[i].set(jnp.power(rks, i)) - b = b.at[i].set(current_h_phi_k * factorial_val / B_h) - - # Conditionally update phi_k and factorial for the next iteration - def update_fn(vals): - # This branch is taken if i < order - 1 - _h_phi_k, _fac = vals - next_fac = _fac * (i + 2) - next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac - return next_h_phi_k, next_fac - - current_h_phi_k, factorial_val = jax.lax.cond( - i < order - 1, - update_fn, # If true, update values - lambda vals: vals, # If false, pass through - (current_h_phi_k, factorial_val), - ) - return R, b, current_h_phi_k, factorial_val - - # Pre-allocate R and b to max size - R_init = jnp.zeros( - (self.config.solver_order, self.config.solver_order), dtype=h.dtype - ) - b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10) + lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10) - # Initialize loop carriers - init_h_phi_k = h_phi_1 / hh - 1.0 - init_factorial = 1.0 + h = lambda_t - lambda_s0 - R, b, _, _ = jax.lax.fori_loop( - 0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial) - ) + def rk_d1_loop_body(i, carry): + # Loop from i = 0 to order-1. + rks, D1s = carry - if len(D1s) > 0: - D1s = jnp.stack(D1s, axis=1) # (B, K, C, H, W) + # Get history from state buffer + history_idx = self.config.solver_order - (i + 2) + mi = state.model_outputs[history_idx] + si_val = state.timestep_list[history_idx] # This is the actual timestep value - def solve_for_rhos(R_mat, b_vec, current_order): - # Create a mask to select the first `current_order` elements - mask = jnp.arange(self.config.solver_order) < current_order - mask_2d = mask[:, None] & mask[None, :] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(state.sigmas[self.index_for_timestep(state, si_val)]) + lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10) - # Pad R with identity and b with zeros to create a safe, full-sized system - R_safe = jnp.where( - mask_2d, R_mat, jnp.eye(self.config.solver_order, dtype=R_mat.dtype) - ) - b_safe = jnp.where(mask, b_vec, 0.0) + rk = (lambda_si - lambda_s0) / h + Di = (mi - m0) / rk - # Solve the full-size system and mask the result - solved_rhos = jnp.linalg.solve(R_safe, b_safe) - return jnp.where(mask, solved_rhos, 0.0) + # Update pre-allocated arrays + rks = rks.at[i].set(rk) + D1s = D1s.at[i].set(Di) + return rks, D1s - rhos_c_order1 = ( - jnp.zeros(self.config.solver_order, dtype=x_t.dtype).at[0].set(0.5) - ) - rhos_c_general = solve_for_rhos(R, b, order) - rhos_c = jnp.where(order == 1, rhos_c_order1, rhos_c_general) + # Pre-allocate arrays to max possible size + rks_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + D1s_init = jnp.zeros((self.config.solver_order - 1, *m0.shape), dtype=m0.dtype) + if self.config.solver_order == 1: + # Dummy D1s array. It will not be used if order == 1. This is for tracing. + D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype) - D1_t = model_t - m0 + # Run the loop up to `order - 1` + rks, D1s = jax.lax.fori_loop(0, order - 1, rk_d1_loop_body, (rks_init, D1s_init)) - corr_res = jax.lax.cond( - order > 1, - lambda _: (jnp.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)), - lambda _: jnp.zeros_like(D1_t), - operand=None, - ) + rks = rks.at[order - 1].set(1.0) - final_rho = jnp.dot( - rhos_c, - jax.nn.one_hot(order - 1, self.config.solver_order, dtype=rhos_c.dtype), - ) + hh = -h if self.config.predict_x0 else h + h_phi_1 = jnp.expm1(hh) - if self.config.predict_x0: - x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - x_t = x_t_ - alpha_t * B_h * (corr_res + final_rho * D1_t) - else: - x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - x_t = x_t_ - sigma_t * B_h * (corr_res + final_rho * D1_t) - - return x_t.astype(x.dtype) - - def index_for_timestep( - self, - state: UniPCMultistepSchedulerState, - timestep: Union[int, jnp.ndarray], - schedule_timesteps: Optional[jnp.ndarray] = None, - ) -> int: - """ "Gets the step_index for timestep.""" - if schedule_timesteps is None: - schedule_timesteps = state.timesteps - - # QUINN!! - # timestep_val = ( - # timestep.item() - # if isinstance(timestep, jnp.ndarray) and timestep.ndim == 0 - # else timestep - # ) - timestep_val = timestep - - index_candidates = jnp.where( - schedule_timesteps == timestep_val, size=1, fill_value=-1 - )[0] - - step_index = jnp.where( - index_candidates[0] == -1, # No match found - len(schedule_timesteps) - 1, # Default to last index - index_candidates[0], - ) - return step_index - - def _init_step_index( - self, state: UniPCMultistepSchedulerState, timestep: Union[int, jnp.ndarray] - ) -> UniPCMultistepSchedulerState: - """Initializes the step_index counter for the scheduler.""" - if state.begin_index is None: - step_index_val = self.index_for_timestep(state, timestep) - return state.replace(step_index=step_index_val) - else: - return state.replace(step_index=state.begin_index) - - @partial(jax.jit, static_argnums=(0, 5)) # self is static_argnum=0 - def step( - self, - state: UniPCMultistepSchedulerState, - model_output: jnp.ndarray, # This is the direct output from the diffusion model (e.g., noise prediction) - timestep: Union[ - int, jnp.ndarray - ], # Current discrete timestep from the scheduler's sequence - sample: jnp.ndarray, # Current noisy sample (latent) - return_dict: bool = True, - generator: Optional[jax.random.PRNGKey] = None, # JAX random key - ) -> Union[FlaxUniPCMultistepSchedulerOutput, Tuple[jnp.ndarray]]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with - the multistep UniPC. - """ - if state.timesteps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - timestep_scalar = jnp.array(timestep) - - # Initialize step_index if it's the first step - if state.step_index is None: - state = self._init_step_index(state, timestep_scalar) - - # Determine if corrector should be used - use_corrector = ( - (state.step_index > 0) - & ( - ~jnp.isin( - state.step_index - 1, jnp.array(self.config.disable_corrector) - ) - ) - & (state.last_sample is not None) - ) + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = jnp.expm1(hh) + else: + raise NotImplementedError() - # Convert model_output (noise/v_pred) to x0_pred or epsilon_pred, based on prediction_type - model_output_for_history = self.convert_model_output( - state, model_output, sample - ) + def rb_loop_body(i, carry): + # Loop from i = 0 to order-1 + R, b, current_h_phi_k, factorial_val = carry - # Apply corrector if applicable - sample = jax.lax.cond( - use_corrector, - lambda: self.multistep_uni_c_bh_update( - state=state, - this_model_output=model_output_for_history, - last_sample=state.last_sample, - this_sample=sample, - order=state.this_order, - ), - lambda: sample, - ) + R = R.at[i].set(jnp.power(rks, i)) + b = b.at[i].set(current_h_phi_k * factorial_val / B_h) - # Update history buffers (model_outputs and timestep_list) - # Shift existing elements to the left and add new one at the end. - # `state.model_outputs` and `state.timestep_list` are fixed-size arrays. - # Example: - # t0:[None,...,model_output0] - # t1:[None,..model_output0,model_output1] - # ... - # tn:[model_output0,model_output1,...,model_output_n] - def step_idx0_branch(): - updated_model_outputs_history = state.model_outputs.at[-1].set( - model_output_for_history - ) - updated_timestep_list_history = state.timestep_list.at[-1].set( - timestep_scalar - ) - return updated_model_outputs_history, updated_timestep_list_history - - def non_step_idx0_branch(): - updated_model_outputs_history = jnp.roll( - state.model_outputs, shift=-1, axis=0 - ) - updated_model_outputs_history = updated_model_outputs_history.at[-1].set( - model_output_for_history - ) - - updated_timestep_list_history = jnp.roll(state.timestep_list, shift=-1) - updated_timestep_list_history = updated_timestep_list_history.at[-1].set( - timestep_scalar - ) - return updated_model_outputs_history, updated_timestep_list_history - - updated_model_outputs_history, updated_timestep_list_history = jax.lax.cond( - state.step_index == 0, step_idx0_branch, non_step_idx0_branch - ) - state = state.replace( - model_outputs=updated_model_outputs_history, - timestep_list=updated_timestep_list_history, - ) + # Conditionally update phi_k and factorial for the next iteration + def update_fn(vals): + # This branch is taken if i < order - 1 + _h_phi_k, _fac = vals + next_fac = _fac * (i + 2) + next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac + return next_h_phi_k, next_fac - # Determine the order for the current step (warmup phase logic) - this_order = jnp.where( - self.config.lower_order_final, - jnp.minimum( - self.config.solver_order, len(state.timesteps) - state.step_index - ), - self.config.solver_order, - ) + current_h_phi_k, factorial_val = jax.lax.cond( + i < order - 1, + update_fn, # If true, update values + lambda vals: vals, # If false, pass through + (current_h_phi_k, factorial_val), + ) + return R, b, current_h_phi_k, factorial_val + + # Pre-allocate R and b to max size + R_init = jnp.zeros((self.config.solver_order, self.config.solver_order), dtype=h.dtype) + b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + + # Initialize loop carriers + init_h_phi_k = h_phi_1 / hh - 1.0 + init_factorial = 1.0 + + R, b, _, _ = jax.lax.fori_loop(0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial)) + + if len(D1s) > 0: + D1s = jnp.stack(D1s, axis=1) # (B, K, C, H, W) + + def solve_for_rhos(R_mat, b_vec, current_order): + # Create a mask to select the first `current_order` elements + mask = jnp.arange(self.config.solver_order) < current_order + mask_2d = mask[:, None] & mask[None, :] + + # Pad R with identity and b with zeros to create a safe, full-sized system + R_safe = jnp.where(mask_2d, R_mat, jnp.eye(self.config.solver_order, dtype=R_mat.dtype)) + b_safe = jnp.where(mask, b_vec, 0.0) + + # Solve the full-size system and mask the result + solved_rhos = jnp.linalg.solve(R_safe, b_safe) + return jnp.where(mask, solved_rhos, 0.0) + + rhos_c_order1 = jnp.zeros(self.config.solver_order, dtype=x_t.dtype).at[0].set(0.5) + rhos_c_general = solve_for_rhos(R, b, order) + rhos_c = jnp.where(order == 1, rhos_c_order1, rhos_c_general) + + D1_t = model_t - m0 + + corr_res = jax.lax.cond( + order > 1, + lambda _: (jnp.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)), + lambda _: jnp.zeros_like(D1_t), + operand=None, + ) + + final_rho = jnp.dot( + rhos_c, + jax.nn.one_hot(order - 1, self.config.solver_order, dtype=rhos_c.dtype), + ) + + if self.config.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + final_rho * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + final_rho * D1_t) + + return x_t.astype(x.dtype) + + def index_for_timestep( + self, + state: UniPCMultistepSchedulerState, + timestep: Union[int, jnp.ndarray], + schedule_timesteps: Optional[jnp.ndarray] = None, + ) -> int: + """ "Gets the step_index for timestep.""" + if schedule_timesteps is None: + schedule_timesteps = state.timesteps + + # QUINN!! + # timestep_val = ( + # timestep.item() + # if isinstance(timestep, jnp.ndarray) and timestep.ndim == 0 + # else timestep + # ) + timestep_val = timestep + + index_candidates = jnp.where(schedule_timesteps == timestep_val, size=1, fill_value=-1)[0] + + step_index = jnp.where( + index_candidates[0] == -1, # No match found + len(schedule_timesteps) - 1, # Default to last index + index_candidates[0], + ) + return step_index + + def _init_step_index( + self, state: UniPCMultistepSchedulerState, timestep: Union[int, jnp.ndarray] + ) -> UniPCMultistepSchedulerState: + """Initializes the step_index counter for the scheduler.""" + if state.begin_index is None: + step_index_val = self.index_for_timestep(state, timestep) + return state.replace(step_index=step_index_val) + else: + return state.replace(step_index=state.begin_index) + + @partial(jax.jit, static_argnums=(0, 5)) # self is static_argnum=0 + def step( + self, + state: UniPCMultistepSchedulerState, + model_output: jnp.ndarray, # This is the direct output from the diffusion model (e.g., noise prediction) + timestep: Union[int, jnp.ndarray], # Current discrete timestep from the scheduler's sequence + sample: jnp.ndarray, # Current noisy sample (latent) + return_dict: bool = True, + generator: Optional[jax.random.PRNGKey] = None, # JAX random key + ) -> Union[FlaxUniPCMultistepSchedulerOutput, Tuple[jnp.ndarray]]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + """ + if state.timesteps is None: + raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler") + + timestep_scalar = jnp.array(timestep) + + # Initialize step_index if it's the first step + if state.step_index is None: + state = self._init_step_index(state, timestep_scalar) - # Warmup for multistep: `this_order` can't exceed `lower_order_nums + 1` - new_this_order = jnp.minimum(this_order, state.lower_order_nums + 1) - state = state.replace(this_order=new_this_order) + # Determine if corrector should be used + use_corrector = ( + (state.step_index > 0) + & (~jnp.isin(state.step_index - 1, jnp.array(self.config.disable_corrector))) + & (state.last_sample is not None) + ) - # Store current sample as `last_sample` for the *next* step's corrector - state = state.replace(last_sample=sample) + # Convert model_output (noise/v_pred) to x0_pred or epsilon_pred, based on prediction_type + model_output_for_history = self.convert_model_output(state, model_output, sample) - # UniP predictor step - prev_sample = self.multistep_uni_p_bh_update( + # Apply corrector if applicable + sample = jax.lax.cond( + use_corrector, + lambda: self.multistep_uni_c_bh_update( state=state, - model_output=model_output, - sample=sample, + this_model_output=model_output_for_history, + last_sample=state.last_sample, + this_sample=sample, order=state.this_order, - ) + ), + lambda: sample, + ) - # Update lower_order_nums for warmup - new_lower_order_nums = jnp.where( - state.lower_order_nums < self.config.solver_order, - state.lower_order_nums + 1, - state.lower_order_nums, - ) - state = state.replace(lower_order_nums=new_lower_order_nums) - # Upon completion, increase step index by one - state = state.replace(step_index=state.step_index + 1) - - # Return the updated sample and state - if not return_dict: - return (prev_sample, state) - - return FlaxUniPCMultistepSchedulerOutput(prev_sample=prev_sample, state=state) - - def scale_model_input( - self, state: UniPCMultistepSchedulerState, sample: jnp.ndarray, *args, **kwargs - ) -> jnp.ndarray: - """ - UniPC does not scale model input, so it returns the sample unchanged. - """ - return sample - - def add_noise( - self, - state: UniPCMultistepSchedulerState, - original_samples: jnp.ndarray, - noise: jnp.ndarray, - timesteps: jnp.ndarray, - ) -> jnp.ndarray: - return add_noise_common(state.common, original_samples, noise, timesteps) - - def _sigma_to_alpha_sigma_t(self, sigma): - if self.config.use_flow_sigmas: - alpha_t = 1 - sigma - sigma_t = sigma - else: - alpha_t = 1 / ((sigma**2 + 1) ** 0.5) - sigma_t = sigma * alpha_t - - return alpha_t, sigma_t - - def __len__(self) -> int: - return self.config.num_train_timesteps + # Update history buffers (model_outputs and timestep_list) + # Shift existing elements to the left and add new one at the end. + # `state.model_outputs` and `state.timestep_list` are fixed-size arrays. + # Example: + # t0:[None,...,model_output0] + # t1:[None,..model_output0,model_output1] + # ... + # tn:[model_output0,model_output1,...,model_output_n] + def step_idx0_branch(): + updated_model_outputs_history = state.model_outputs.at[-1].set(model_output_for_history) + updated_timestep_list_history = state.timestep_list.at[-1].set(timestep_scalar) + return updated_model_outputs_history, updated_timestep_list_history + + def non_step_idx0_branch(): + updated_model_outputs_history = jnp.roll(state.model_outputs, shift=-1, axis=0) + updated_model_outputs_history = updated_model_outputs_history.at[-1].set(model_output_for_history) + + updated_timestep_list_history = jnp.roll(state.timestep_list, shift=-1) + updated_timestep_list_history = updated_timestep_list_history.at[-1].set(timestep_scalar) + return updated_model_outputs_history, updated_timestep_list_history + + updated_model_outputs_history, updated_timestep_list_history = jax.lax.cond( + state.step_index == 0, step_idx0_branch, non_step_idx0_branch + ) + state = state.replace( + model_outputs=updated_model_outputs_history, + timestep_list=updated_timestep_list_history, + ) + + # Determine the order for the current step (warmup phase logic) + this_order = jnp.where( + self.config.lower_order_final, + jnp.minimum(self.config.solver_order, len(state.timesteps) - state.step_index), + self.config.solver_order, + ) + + # Warmup for multistep: `this_order` can't exceed `lower_order_nums + 1` + new_this_order = jnp.minimum(this_order, state.lower_order_nums + 1) + state = state.replace(this_order=new_this_order) + + # Store current sample as `last_sample` for the *next* step's corrector + state = state.replace(last_sample=sample) + + # UniP predictor step + prev_sample = self.multistep_uni_p_bh_update( + state=state, + model_output=model_output, + sample=sample, + order=state.this_order, + ) + + # Update lower_order_nums for warmup + new_lower_order_nums = jnp.where( + state.lower_order_nums < self.config.solver_order, + state.lower_order_nums + 1, + state.lower_order_nums, + ) + state = state.replace(lower_order_nums=new_lower_order_nums) + # Upon completion, increase step index by one + state = state.replace(step_index=state.step_index + 1) + + # Return the updated sample and state + if not return_dict: + return (prev_sample, state) + + return FlaxUniPCMultistepSchedulerOutput(prev_sample=prev_sample, state=state) + + def scale_model_input(self, state: UniPCMultistepSchedulerState, sample: jnp.ndarray, *args, **kwargs) -> jnp.ndarray: + """ + UniPC does not scale model input, so it returns the sample unchanged. + """ + return sample + + def add_noise( + self, + state: UniPCMultistepSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return add_noise_common(state.common, original_samples, noise, timesteps) + + def _sigma_to_alpha_sigma_t(self, sigma): + if self.config.use_flow_sigmas: + alpha_t = 1 - sigma + sigma_t = sigma + else: + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + def __len__(self) -> int: + return self.config.num_train_timesteps diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index be4d44f2c..3b013b791 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -37,7 +37,16 @@ def setUp(self): def test_splash_attention(self): """Test numerics of splash attention are equivalent to dot_product""" - pyconfig.initialize([None, os.path.join(THIS_DIR, "..", "configs", "base21.yml")], unittest=True) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base21.yml"), + 'flash_block_sizes={"block_q" : 512, "block_kv_compute": 512, "block_kv": 512,' + '"block_q_dkv": 512, "block_kv_dkv": 512, "block_kv_dkv_compute": 512,' + '"block_q_dq": 512, "block_kv_dq": 512}', + ], + unittest=True, + ) config = pyconfig.config batch = 8 @@ -47,7 +56,6 @@ def test_splash_attention(self): key1, key2 = jax.random.split(jax.random.PRNGKey(0)) x = jax.random.normal(key1, (batch, length, heads * head_depth)) - dot_product_attention = FlaxAttention( heads * head_depth, heads, @@ -64,9 +72,16 @@ def test_splash_attention(self): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - + flash_block_sizes = max_utils.get_flash_block_sizes(config) splash_attention = FlaxAttention( - heads * head_depth, heads, head_depth, split_head_dim=True, attention_kernel="flash", mesh=mesh, dtype=jnp.bfloat16 + heads * head_depth, + heads, + head_depth, + split_head_dim=True, + attention_kernel="flash", + mesh=mesh, + dtype=jnp.bfloat16, + flash_block_sizes=flash_block_sizes, ) params = splash_attention.init(key2, x)["params"] diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py new file mode 100644 index 000000000..17741191a --- /dev/null +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -0,0 +1,286 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import os +import jax +import jax.numpy as jnp +import pytest +import unittest +from absl.testing import absltest +from flax import nnx +from jax.sharding import Mesh + +from .. import pyconfig +from ..max_utils import (create_device_mesh, get_flash_block_sizes) +from ..models.wan.transformers.transformer_wan import ( + WanRotaryPosEmbed, + WanTimeTextImageEmbedding, + WanTransformerBlock, + WanModel, +) +from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection +from ..models.normalization_flax import FP32LayerNorm +from ..models.attention_flax import FlaxWanAttention + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +class WanTransformerTest(unittest.TestCase): + + def setUp(self): + WanTransformerTest.dummy_data = {} + + def test_rotary_pos_embed(self): + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_states_shape = (batch_size, frames, height, width, channels) + dummy_hidden_states = jnp.ones(hidden_states_shape) + wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024) + dummy_output = wan_rot_embed(dummy_hidden_states) + assert dummy_output.shape == (1, 1, 75600, 64) + + def test_nnx_pixart_alpha_text_projection(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + dummy_caption = jnp.ones((1, 512, 4096)) + layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120) + dummy_output = layer(dummy_caption) + dummy_output.shape == (1, 512, 5120) + + def test_nnx_timestep_embedding(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + + dummy_sample = jnp.ones((1, 256)) + layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120) + dummy_output = layer(dummy_sample) + assert dummy_output.shape == (1, 5120) + + def test_fp32_layer_norm(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + batch_size = 1 + dummy_hidden_states = jnp.ones((batch_size, 75600, 5120)) + # expected same output shape with same dtype + layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False) + dummy_output = layer(dummy_hidden_states) + assert dummy_output.shape == dummy_hidden_states.shape + + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") + def test_wan_time_text_embedding(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + batch_size = 1 + dim = 5120 + time_freq_dim = 256 + time_proj_dim = 30720 + text_embed_dim = 4096 + layer = WanTimeTextImageEmbedding( + rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim + ) + + dummy_timestep = jnp.ones(batch_size) + + encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim) + dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape) + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer( + dummy_timestep, dummy_encoder_hidden_states + ) + assert temb.shape == (batch_size, dim) + assert timestep_proj.shape == (batch_size, time_proj_dim) + assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim) + + def test_wan_block(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + + devices_array = create_device_mesh(config) + + flash_block_sizes = get_flash_block_sizes(config) + + mesh = Mesh(devices_array, config.mesh_axes) + + dim = 5120 + ffn_dim = 13824 + num_heads = 40 + qk_norm = "rms_norm_across_heads" + cross_attn_norm = True + eps = 1e-6 + + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_dim = 75600 + + # for rotary post embed. + hidden_states_shape = (batch_size, frames, height, width, channels) + dummy_hidden_states = jnp.ones(hidden_states_shape) + + wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024) + dummy_rotary_emb = wan_rot_embed(dummy_hidden_states) + assert dummy_rotary_emb.shape == (batch_size, 1, hidden_dim, 64) + + # for transformer block + dummy_hidden_states = jnp.ones((batch_size, hidden_dim, dim)) + + dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim)) + + dummy_temb = jnp.ones((batch_size, 6, dim)) + + wan_block = WanTransformerBlock( + rngs=rngs, + dim=dim, + ffn_dim=ffn_dim, + num_heads=num_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + attention="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + with mesh: + dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb) + assert dummy_output.shape == dummy_hidden_states.shape + + def test_wan_attention(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_states_shape = (batch_size, frames, height, width, channels) + dummy_hidden_states = jnp.ones(hidden_states_shape) + wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024) + dummy_rotary_emb = wan_rot_embed(dummy_hidden_states) + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + devices_array = create_device_mesh(config) + + flash_block_sizes = get_flash_block_sizes(config) + + mesh = Mesh(devices_array, config.mesh_axes) + batch_size = 1 + query_dim = 5120 + attention = FlaxWanAttention( + rngs=rngs, + query_dim=query_dim, + heads=40, + dim_head=128, + attention_kernel="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + + dummy_hidden_states_shape = (batch_size, 75600, query_dim) + + dummy_hidden_states = jnp.ones(dummy_hidden_states_shape) + dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape) + with mesh: + dummy_output = attention( + hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb + ) + assert dummy_output.shape == dummy_hidden_states_shape + + # dot product + try: + attention = FlaxWanAttention( + rngs=rngs, + query_dim=query_dim, + heads=40, + dim_head=128, + attention_kernel="dot_product", + split_head_dim=True, + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + except NotImplementedError: + pass + + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") + def test_wan_model(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_states_shape = (batch_size, channels, frames, height, width) + dummy_hidden_states = jnp.ones(hidden_states_shape) + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + devices_array = create_device_mesh(config) + + flash_block_sizes = get_flash_block_sizes(config) + + mesh = Mesh(devices_array, config.mesh_axes) + batch_size = 1 + wan_model = WanModel( + rngs=rngs, + attention="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + + dummy_timestep = jnp.ones((batch_size)) + dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) + with mesh: + dummy_output = wan_model( + hidden_states=dummy_hidden_states, + timestep=dummy_timestep, + encoder_hidden_states=dummy_encoder_hidden_states, + is_uncond=jnp.array(True, dtype=jnp.bool_), + slg_mask=jnp.zeros(40, dtype=jnp.bool_), + ) + assert dummy_output.shape == hidden_states_shape + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 7d750c8bb..7b131e7fb 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -14,6 +14,7 @@ limitations under the License. """ +import os import functools import torch import torch.nn as nn @@ -21,6 +22,11 @@ import jax import jax.numpy as jnp from flax import nnx +from jax.sharding import Mesh +from .. import pyconfig +from ..max_utils import ( + create_device_mesh, +) import numpy as np import unittest from absl.testing import absltest @@ -41,6 +47,8 @@ from ..utils import load_video from ..video_processor import VideoProcessor +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + CACHE_T = 2 @@ -249,6 +257,17 @@ def test_wan_resample(self): def test_3d_conv(self): key = jax.random.key(0) rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + batch_size = 1 in_depth, in_height, in_width = 10, 32, 32 in_channels = 3 @@ -269,7 +288,8 @@ def test_3d_conv(self): out_channels=out_channels, kernel_size=(kernel_d, kernel_h, kernel_w), padding=(padding_d, padding_h, padding_w), - rngs=rngs, # Pass rngs for initialization + rngs=rngs, # Pass rngs for initialization, + mesh=mesh, ) # --- Test Case 1: No Cache --- @@ -289,6 +309,16 @@ def test_3d_conv(self): def test_wan_residual(self): key = jax.random.key(0) rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) # --- Test Case 1: same in/out dim --- in_dim = out_dim = 96 batch = 1 @@ -299,11 +329,7 @@ def test_wan_residual(self): input_shape = (batch, t, height, width, dim) expected_output_shape = (batch, t, height, width, dim) - wan_residual_block = WanResidualBlock( - in_dim=in_dim, - out_dim=out_dim, - rngs=rngs, - ) + wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) dummy_input = jnp.ones(input_shape) dummy_output = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape @@ -313,11 +339,7 @@ def test_wan_residual(self): out_dim = 196 expected_output_shape = (batch, t, height, width, out_dim) - wan_residual_block = WanResidualBlock( - in_dim=in_dim, - out_dim=out_dim, - rngs=rngs, - ) + wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) dummy_input = jnp.ones(input_shape) dummy_output = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape @@ -339,13 +361,23 @@ def test_wan_attention(self): def test_wan_midblock(self): key = jax.random.key(0) rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) batch = 1 t = 1 dim = 384 height = 60 width = 90 input_shape = (batch, t, height, width, dim) - wan_midblock = WanMidBlock(dim=dim, rngs=rngs) + wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh) dummy_input = jnp.ones(input_shape) output = wan_midblock(dummy_input) assert output.shape == input_shape @@ -353,6 +385,16 @@ def test_wan_midblock(self): def test_wan_decode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) dim = 96 z_dim = 16 dim_mult = [1, 2, 4, 4] @@ -367,6 +409,7 @@ def test_wan_decode(self): num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, + mesh=mesh, ) vae_cache = AutoencoderKLWanCache(wan_vae) batch = 1 @@ -386,6 +429,16 @@ def test_wan_decode(self): def test_wan_encode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) dim = 96 z_dim = 16 dim_mult = [1, 2, 4, 4] @@ -400,6 +453,7 @@ def test_wan_encode(self): num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, + mesh=mesh, ) vae_cache = AutoencoderKLWanCache(wan_vae) batch = 1 @@ -418,10 +472,20 @@ def vae_encode(video, wan_vae, vae_cache, key): latent = latent.latent_dist.sample(key) return latent - pretrained_model_name_or_path = "Wan-AI/Wan2.1-T2V-14B-Diffusers" key = jax.random.key(0) rngs = nnx.Rngs(key) - wan_vae = AutoencoderKLWan.from_config(pretrained_model_name_or_path, subfolder="vae", rngs=rngs) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh) vae_cache = AutoencoderKLWanCache(wan_vae) video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" video = load_video(video_path) @@ -435,7 +499,7 @@ def vae_encode(video, wan_vae, vae_cache, key): graphdef, state = nnx.split(wan_vae) params = state.to_pure_dict() # This replaces random params with the model. - params = load_wan_vae(pretrained_model_name_or_path, params, "cpu") + params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) wan_vae = nnx.merge(graphdef, params) diff --git a/src/maxdiffusion/train_wan.py b/src/maxdiffusion/train_wan.py new file mode 100644 index 000000000..8d4774987 --- /dev/null +++ b/src/maxdiffusion/train_wan.py @@ -0,0 +1,41 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +from typing import Sequence + +import jax +from absl import app +from maxdiffusion import max_logging, pyconfig +from maxdiffusion.train_utils import validate_train_config + + +def train(config): + from maxdiffusion.trainers.wan_trainer import WanTrainer + + trainer = WanTrainer(config) + trainer.start_training() + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + config = pyconfig.config + validate_train_config(config) + max_logging.log(f"Found {jax.device_count()} devices.") + train(config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py new file mode 100644 index 000000000..3740e2cf1 --- /dev/null +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -0,0 +1,170 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import os +import datetime +import functools +import numpy as np +import jax.numpy as jnp +import jax +import jax.tree_util as jtu +from flax import nnx +from flax.linen import partitioning as nn_partitioning +from ..schedulers import FlaxEulerDiscreteScheduler +from .. import max_utils, max_logging, train_utils, maxdiffusion_utils +from ..checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT) +from multihost_dataloading import _form_global_array + + +class WanTrainer(WanCheckpointer): + + def __init__(self, config): + WanCheckpointer.__init__(self, config, WAN_CHECKPOINT) + if config.train_text_encoder: + raise ValueError("this script currently doesn't support training text_encoders") + + self.global_batch_size = self.config.per_device_batch_size * jax.device_count() + + def post_training_steps(self, pipeline, params, train_states, msg=""): + pass + + def create_scheduler(self, pipeline, params): + # TODO - set right scheduler + noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( + pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, subfolder="scheduler", dtype=jnp.float32 + ) + noise_scheduler_state = noise_scheduler.set_timesteps( + state=noise_scheduler_state, num_inference_steps=self.config.num_inference_steps, timestep_spacing="flux" + ) + return noise_scheduler, noise_scheduler_state + + def calculate_tflops(self, pipeline): + max_logging.log("WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0...") + return 0 + + def load_dataset(self, pipeline): + # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314 + # Image pre-training - txt2img 256px + # Image-video joint training - stage 1. 256 px images and 192px 5 sec videos at fps=16 + # Image-video joint training - stage 2. 480px images and 480px 5 sec videos at fps=16 + # Image-video joint training - stage final. 720px images and 720px 5 sec videos at fps=16 + # prompt embeds shape: (1, 512, 4096) + # For now, we will pass the same latents over and over + # TODO - create a dataset + return maxdiffusion_utils.get_dummy_wan_inputs(self.config, pipeline, self.global_batch_size) + + def start_training(self): + + pipeline = self.load_checkpoint() + del pipeline.vae + dummy_inputs = self.load_dataset(pipeline) + mesh = pipeline.mesh + optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5) + dummy_inputs = tuple( + [jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs] + ) + self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs) + + def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): + + graphdef, state = nnx.split((pipeline.transformer, optimizer)) + writer = max_utils.initialize_summary_writer(self.config) + num_model_parameters = max_utils.calculate_num_params_from_pytree(state[0]) + max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer) + max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer) + max_utils.add_config_to_summary_writer(self.config, writer) + + if jax.process_index() == 0: + max_logging.log("***** Running training *****") + max_logging.log(f" Instantaneous batch size per device = {self.config.per_device_batch_size}") + max_logging.log(f" Total train batch size (w. parallel & distributed) = {self.global_batch_size}") + max_logging.log(f" Total optimization steps = {self.config.max_train_steps}") + + state = state.to_pure_dict() + p_train_step = jax.jit( + train_step, + donate_argnums=(0,), + ) + rng = jax.random.key(self.config.seed) + start_step = 0 + last_step_completion = datetime.datetime.now() + local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None + running_gcs_metrics = [] if self.config.gcs_metrics else None + first_profiling_step = self.config.skip_first_n_steps_for_profiler + if self.config.enable_profiler and first_profiling_step >= self.config.max_train_steps: + raise ValueError("Profiling requested but initial profiling step set past training final step") + last_profiling_step = np.clip( + first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1 + ) + # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint. + start_step = 0 + per_device_tflops = self.calculate_tflops(pipeline) + + for step in np.arange(start_step, self.config.max_train_steps): + if self.config.enable_profiler and step == first_profiling_step: + max_utils.activate_profiler(self.config) + with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules( + self.config.logical_axis_rules + ): + state, train_metric, rng = p_train_step(state, graphdef, data, rng) + + new_time = datetime.datetime.now() + + if self.config.enable_profiler and step == last_profiling_step: + max_utils.deactivate_profiler(self.config) + + train_utils.record_scalar_metrics( + train_metric, new_time - last_step_completion, per_device_tflops, learning_rate_scheduler(step) + ) + if self.config.write_metrics: + train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) + last_step_completion = new_time + + +def train_step(state, graphdef, data, rng): + return step_optimizer(graphdef, state, data, rng) + + +def step_optimizer(graphdef, state, data, rng): + _, new_rng = jax.random.split(rng) + + def loss_fn(model): + latents, prompt_embeds, timesteps = data + + noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) + + # TODO - add noise here + + model_pred = model( + hidden_states=noise, + timestep=timesteps, + encoder_hidden_states=prompt_embeds, + is_uncond=jnp.array(False, dtype=jnp.bool_), + slg_mask=jnp.zeros(1, dtype=jnp.bool_), + ) + target = noise - latents + loss = (target - model_pred) ** 2 + loss = jnp.mean(loss) + # breakpoint() + return loss + + model, optimizer = nnx.merge(graphdef, state) + loss, grads = nnx.value_and_grad(loss_fn)(model) + optimizer.update(grads) + state = nnx.state((model, optimizer)) + state = state.to_pure_dict() + metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} + return state, metrics, new_rng