diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 728d2f2e3..05f332fb7 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -50,7 +50,7 @@ jobs: ruff check . - name: PyTest run: | - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/requirements.txt b/requirements.txt index 879b62d54..2ccbd88ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,6 +23,8 @@ pytest==8.2.2 tensorflow>=2.17.0 tensorflow-datasets>=4.9.6 ruff>=0.1.5,<=0.2 +git+https://github.com/Lightricks/LTX-Video +git+https://github.com/zmelumian972/xla@torchax/jittable_module_callable#subdirectory=torchax opencv-python-headless==4.10.0.84 orbax-checkpoint==0.10.3 tokenizers==0.21.0 diff --git a/setup.sh b/setup.sh index a62e16918..15932df86 100644 --- a/setup.sh +++ b/setup.sh @@ -112,4 +112,4 @@ else fi # Install maxdiffusion -pip3 install -U . || echo "Failed to install maxdiffusion" >&2 +pip3 install -U . || echo "Failed to install maxdiffusion" >&2 \ No newline at end of file diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 7415ed682..42e50d775 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -374,6 +374,7 @@ _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"] _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] + _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) _import_structure["schedulers"].extend( [ @@ -453,6 +454,7 @@ from .models.modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .models.ltx_video.transformers.transformer3d import Transformer3DModel from .models.vae_flax import FlaxAutoencoderKL from .pipelines import FlaxDiffusionPipeline from .schedulers import ( diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index dd78eaa6c..1a8d12ef0 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -213,8 +213,11 @@ def load_state_if_possible( max_logging.log(f"restoring from this run's directory latest step {latest_step}") try: if not enable_single_replica_ckpt_restoring: - item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} - return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) + if checkpoint_item == "ltxvid_transformer": + return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.StandardRestore(abstract_unboxed_pre_state)) + else: + item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} + return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) def map_to_pspec(data): pspec = data.sharding.spec diff --git
a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml new file mode 100644 index 000000000..fce674f2c --- /dev/null +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -0,0 +1,99 @@ +#hardware +hardware: 'tpu' +skip_jax_distributed_system: False + +jax_cache_dir: '' +weights_dtype: 'bfloat16' +activations_dtype: 'bfloat16' + + +run_name: '' +output_dir: '' +config_path: '' +save_config_to_gcs: False + +#Checkpoints +text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax" +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +frame_rate: 30 +max_sequence_length: 512 +sampler: "from_checkpoint" + +# Generation parameters +pipeline_type: multi-scale +prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie. " +#negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" +height: 512 +width: 512 +num_frames: 88 +flow_shift: 5.0 +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors" +prompt_enhancement_words_threshold: 120 +stg_mode: "attention_values" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +seed: 10 + + +first_pass: + guidance_scale: [1, 1, 6, 8, 6, 1, 1] + stg_scale: [0, 0, 4, 4, 4, 2, 1] + rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] + guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] + skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] + num_inference_steps: 30 + skip_final_inference_steps: 3 + skip_initial_inference_steps: 0 + cfg_star_rescale: True + +second_pass: + guidance_scale: [1] + stg_scale: [1] + rescaling_scale: [1] + guidance_timesteps: [1.0] + skip_block_list: [27] + num_inference_steps: 30 + skip_initial_inference_steps: 17 + skip_final_inference_steps: 0 + cfg_star_rescale: True + +#parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] +logical_axis_rules: [ + ['batch', 'data'], + ['activation_heads', 'fsdp'], + ['activation_batch', 'data'], + ['activation_kv', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['norm', 'fsdp'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ['conv_in', 'fsdp'] + ] +data_sharding: [['data', 'fsdp', 'tensor']] +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +allow_split_physical_axes: False +learning_rate_schedule_steps: -1 
+max_train_steps: 500 +pretrained_model_name_or_path: '' +unet_checkpoint: '' +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +per_device_batch_size: 1 +compile_topology_num_slices: -1 +quantization_local_shard_count: -1 +jit_initializers: True +enable_single_replica_ckpt_restoring: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py new file mode 100644 index 000000000..553d6373e --- /dev/null +++ b/src/maxdiffusion/generate_ltx_video.py @@ -0,0 +1,161 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import numpy as np +from absl import app +from typing import Sequence +from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline +from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline +from maxdiffusion import pyconfig, max_logging +import imageio +from datetime import datetime +import os +import time +from pathlib import Path + + +def calculate_padding( + source_height: int, source_width: int, target_height: int, target_width: int +) -> tuple[int, int, int, int]: + + # Calculate total padding needed + pad_height = target_height - source_height + pad_width = target_width - source_width + + # Calculate padding for each side + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top # Handles odd padding + pad_left = pad_width // 2 + pad_right = pad_width - pad_left # Handles odd padding + padding = (pad_left, pad_right, pad_top, pad_bottom) + return padding + + +def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: + # Remove non-letters and convert to lowercase + clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace()) + + # Split into words + words = clean_text.split() + + # Build result string keeping track of length + result = [] + current_length = 0 + + for word in words: + # Add the word only if the accumulated length stays within max_len + new_length = current_length + len(word) + + if new_length <= max_len: + result.append(word) + current_length += len(word) + else: + break + + return "-".join(result) + + +def get_unique_filename( + base: str, + ext: str, + prompt: str, + resolution: tuple[int, int, int], + dir: Path, + endswith=None, + index_range=1000, +) -> Path: + base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{resolution[0]}x{resolution[1]}x{resolution[2]}" + for i in range(index_range): + filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}" + if not os.path.exists(filename): + return filename + raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.") + + +def run(config): + height_padded = ((config.height - 1) // 32 + 1) * 32 + width_padded = ((config.width - 1) // 32 + 1) * 32 + num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 + padding = calculate_padding(config.height, config.width, height_padded,
width_padded) + prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold + prompt_word_count = len(config.prompt.split()) + enhance_prompt = prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold + + pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt) + if config.pipeline_type == "multi-scale": + pipeline = LTXMultiScalePipeline(pipeline) + s0 = time.perf_counter() + images = pipeline( + height=height_padded, + width=width_padded, + num_frames=num_frames_padded, + is_video=True, + output_type="pt", + config=config, + enhance_prompt=enhance_prompt, + seed=config.seed, + ) + max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.") + + (pad_left, pad_right, pad_top, pad_bottom) = padding + pad_bottom = -pad_bottom + pad_right = -pad_right + if pad_bottom == 0: + pad_bottom = images.shape[3] + if pad_right == 0: + pad_right = images.shape[4] + images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right] + output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}") + output_dir.mkdir(parents=True, exist_ok=True) + + for i in range(images.shape[0]): + # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C + video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy() + # Unnormalizing images to [0, 255] range + video_np = (video_np * 255).astype(np.uint8) + fps = config.frame_rate + height, width = video_np.shape[1:3] + # In case a single image is generated + if video_np.shape[0] == 1: + output_filename = get_unique_filename( + f"image_output_{i}", + ".png", + prompt=config.prompt, + resolution=(height, width, config.num_frames), + dir=output_dir, + ) + imageio.imwrite(output_filename, video_np[0]) + else: + output_filename = get_unique_filename( + f"video_output_{i}", + ".mp4", + prompt=config.prompt, + resolution=(height, width, config.num_frames), + dir=output_dir, + ) + # Write video + with imageio.get_writer(output_filename, fps=fps) as video: + for frame in video_np: + video.append_data(frame) + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fe6cc09af..96b60426d 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -405,7 +405,10 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + if checkpoint_item == "ltxvid_transformer": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 95861e24e..96a6f1286 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -13,9 +13,7 @@ # limitations under the License. 
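A quick aside on the geometry handled in `run()` in `generate_ltx_video.py` above: the pipeline rounds height and width up to multiples of 32 and the frame count up to the nearest 8k + 1, renders at the padded size, then crops the padding back out before writing the output. A minimal standalone sketch of that arithmetic, with values taken from `ltx_video.yml` (the helper mirrors `calculate_padding` from the diff):

```python
# Hypothetical standalone check of the pad-then-crop geometry in run();
# values follow ltx_video.yml (height=512, width=512, num_frames=88).
def calculate_padding(src_h: int, src_w: int, tgt_h: int, tgt_w: int):
    pad_h, pad_w = tgt_h - src_h, tgt_w - src_w
    pad_top, pad_left = pad_h // 2, pad_w // 2
    # (left, right, top, bottom); odd remainders go to the right/bottom
    return (pad_left, pad_w - pad_left, pad_top, pad_h - pad_top)

height, width, num_frames = 512, 512, 88
height_padded = ((height - 1) // 32 + 1) * 32            # 512: already a multiple of 32
width_padded = ((width - 1) // 32 + 1) * 32              # 512
num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1  # 89: frame counts must be 8k + 1

print(calculate_padding(height, width, height_padded, width_padded))  # (0, 0, 0, 0)
print(num_frames_padded)  # 89
```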
from typing import TYPE_CHECKING - -from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available - +from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available _import_structure = {} @@ -32,6 +30,7 @@ from .vae_flax import FlaxAutoencoderKL from .lora import * from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .ltx_video.transformers.transformer3d import Transformer3DModel else: import sys diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py new file mode 100644 index 000000000..7e4185f36 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py new file mode 100644 index 000000000..285b6e81c --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py new file mode 100644 index 000000000..7206893d0 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py @@ -0,0 +1,74 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import Tuple, Union + +import torch +import torch.nn as nn + + +class CausalConv3d(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + spatial_padding_mode: str = "zeros", + **kwargs, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode=spatial_padding_mode, + groups=groups, + ) + + def forward(self, x, causal: bool = True): + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self): + return self.conv.weight diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py new file mode 100644 index 000000000..dd94dfbb6 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py @@ -0,0 +1,1300 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
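The causal trick in `CausalConv3d.forward` above is worth spelling out: with `causal=True`, the first frame is replicated `time_kernel_size - 1` times on the left before the temporally un-padded convolution, so the output keeps the input frame count and frame `t` only depends on frames `<= t`. A small sketch of just the padding step (shapes only, assuming a temporal kernel of 3):

```python
import torch

# (B, C, F, H, W): one clip, 8 channels, 5 frames
x = torch.randn(1, 8, 5, 16, 16)
time_kernel_size = 3

# Causal branch: left-pad time with copies of the first frame
first_frame_pad = x[:, :, :1, :, :].repeat(1, 1, time_kernel_size - 1, 1, 1)
padded = torch.cat((first_frame_pad, x), dim=2)
print(padded.shape)  # torch.Size([1, 8, 7, 16, 16])
# A Conv3d with temporal kernel 3 and no temporal padding maps 7 frames back to 5,
# so the output length matches the input and no future frame leaks backwards.
```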
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +import json +import os +from functools import partial +from types import SimpleNamespace +from typing import Any, Mapping, Optional, Tuple, Union, List +from pathlib import Path + +import torch +import numpy as np +from einops import rearrange +from torch import nn +from diffusers.utils import logging +import torch.nn.functional as F +from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from safetensors import safe_open + + +from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from maxdiffusion.models.ltx_video.autoencoders.pixel_norm import PixelNorm +from maxdiffusion.models.ltx_video.autoencoders.pixel_shuffle import PixelShuffleND +from maxdiffusion.models.ltx_video.autoencoders.vae import AutoencoderKLWrapper +from maxdiffusion.models.ltx_video.transformers.attention import Attention +from maxdiffusion.models.ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + VAE_KEYS_RENAME_DICT, +) + +PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics." +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CausalVideoAutoencoder(AutoencoderKLWrapper): + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + if pretrained_model_name_or_path.is_dir() and (pretrained_model_name_or_path / "autoencoder.pth").exists(): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + state_dict = torch.load(model_local_path, map_location=torch.device("cpu")) + + statistics_local_path = pretrained_model_name_or_path / "per_channel_statistics.json" + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = {col: torch.tensor(vals) for col, vals in zip(data["columns"], transposed_data)} + std_of_means = data_dict["std-of-means"] + mean_of_means = data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"])) + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = std_of_means + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = mean_of_means + + elif pretrained_model_name_or_path.is_dir(): + config_path = pretrained_model_name_or_path / "vae" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) + + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for VAE is not supported. " + "We only support diffusers configs found in Lightricks/LTX-Video."
+ ) + + config = diffusers_and_ours_config_mapping[config] + + state_dict_path = pretrained_model_name_or_path / "vae" / "diffusion_pytorch_model.safetensors" + + state_dict = {} + with safe_open(state_dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + state_dict[new_key] = state_dict.pop(key) + + elif pretrained_model_name_or_path.is_file() and str(pretrained_model_name_or_path).endswith(".safetensors"): + state_dict = {} + with safe_open(pretrained_model_name_or_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + config = configs["vae"] + + video_vae = cls.from_config(config) + if "torch_dtype" in kwargs: + video_vae.to(kwargs["torch_dtype"]) + video_vae.load_state_dict(state_dict) + return video_vae + + @staticmethod + def from_config(config): + assert config["_class_name"] == "CausalVideoAutoencoder", "config must have _class_name=CausalVideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get("latent_log_var", "per_channel" if double_z else "none") + use_quant_conv = config.get("use_quant_conv", True) + normalize_latent_channels = config.get("normalize_latent_channels", False) + + if use_quant_conv and latent_log_var in ["uniform", "constant"]: + raise ValueError(f"latent_log_var={latent_log_var} requires use_quant_conv=False") + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + blocks=config.get("encoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + base_channels=config.get("encoder_base_channels", 128), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + blocks=config.get("decoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + causal=config.get("causal_decoder", False), + timestep_conditioning=config.get("timestep_conditioning", False), + base_channels=config.get("decoder_base_channels", 128), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + dims = config["dims"] + return CausalVideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + normalize_latent_channels=normalize_latent_channels, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="CausalVideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2, + out_channels=self.decoder.conv_out.out_channels // self.decoder.patch_size**2, + latent_channels=self.decoder.conv_in.in_channels, + encoder_blocks=self.encoder.blocks_desc, + decoder_blocks=self.decoder.blocks_desc, + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + patch_size=self.encoder.patch_size, + 
latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + causal_decoder=self.decoder.causal, + timestep_conditioning=self.decoder.timestep_conditioning, + normalize_latent_channels=self.normalize_latent_channels, + ) + + @property + def is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 + + @property + def spatial_downscale_factor(self): + return ( + 2 + ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] + in [ + "compress_space", + "compress_all", + "compress_all_res", + "compress_space_res", + ] + ] + ) + * self.encoder.patch_size + ) + + @property + def temporal_downscale_factor(self): + return 2 ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] + in [ + "compress_time", + "compress_all", + "compress_all_res", + "compress_space_res", + ] + ] + ) + + def to_json_string(self) -> str: + import json + + return json.dumps(self.config.__dict__) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + if any([key.startswith("vae.") for key in state_dict.keys()]): # noqa: C419 + state_dict = {key.replace("vae.", ""): value for key, value in state_dict.items() if key.startswith("vae.")} + ckpt_state_dict = {key: value for key, value in state_dict.items() if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)} + + model_keys = set(name for name, _ in self.named_modules()) # noqa: C401 + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + converted_state_dict = {} + for key, value in ckpt_state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + key_prefix = ".".join(key.split(".")[:-1]) + if "norm" in key and key_prefix not in model_keys: + logger.info(f"Removing key {key} from state_dict as it is not present in the model") + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + data_dict = { + key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value + for key, value in state_dict.items() + if key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + } + if len(data_dict) > 0: + self.register_buffer("std_of_means", data_dict["std-of-means"]) + self.register_buffer( + "mean_of_means", + data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"])), + ) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + def set_use_tpu_flash_attention(self): + for block in self.decoder.up_blocks: + if isinstance(block, UNetMidBlock3D) and block.attention_blocks: + for attention_block in block.attention_blocks: + attention_block.set_use_tpu_flash_attention() + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. 
Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self.blocks_desc = blocks + + in_channels = in_channels * patch_size**2 + output_channel = base_channels + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in blocks: + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + 
in_channels=input_channel, + out_channels=output_channel, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + self.down_blocks.append(block) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm(num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var == "constant": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, + output_channel, + conv_out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)(sample) + + sample = self.conv_norm_out(sample).to(torch.bfloat16) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1, 1) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + elif self.latent_log_var == "constant": + sample = sample[:, :-1, ...] + approx_ln_0 = -30 # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. 
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal (`bool`, *optional*, defaults to `True`): + Whether to use causal convolutions or not. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + causal: bool = True, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.patch_size = patch_size + self.layers_per_block = layers_per_block + out_channels = out_channels * patch_size**2 + self.causal = causal + self.blocks_desc = blocks + + # Compute output channel to be product of all channel-multiplier blocks + output_channel = base_channels + for block_name, block_params in list(reversed(blocks)): + block_params = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + output_channel = output_channel * block_params.get("multiplier", 2) + if block_name == "compress_all": + output_channel = output_channel * block_params.get("multiplier", 1) + + self.conv_in = make_conv_nd( + dims, + in_channels, + output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(blocks)): + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_params["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + output_channel = output_channel // block_params.get("multiplier", 2) + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = 
DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + output_channel = output_channel // block_params.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 2, 2), + residual=block_params.get("residual", False), + out_channels_reduction_factor=block_params.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + self.up_blocks.append(block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm(num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, + output_channel, + out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.last_scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + + def forward( + self, + sample: torch.FloatTensor, + target_shape, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + batch_size = sample.shape[0] + + sample = self.conv_in(sample, causal=self.causal) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = sample.to(upscale_dtype) + + if self.timestep_conditioning: + assert timestep is not None, "should pass timestep with timestep_conditioning=True" + scaled_timestep = timestep * self.timestep_scale_multiplier + + for up_block in self.up_blocks: + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)(sample, causal=self.causal, timestep=scaled_timestep) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + sample = self.conv_norm_out(sample).to(torch.bfloat16) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=sample.shape[0], + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1) + ada_values = self.last_scale_shift_table[None, ..., None, None, None] + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. 
+ + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + attention_head_dim (`int`, *optional*, defaults to -1): + The dimension of the attention head. If -1, no attention is used. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + attention_head_dim: int = -1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + for _ in range(num_layers) + ] + ) + + self.attention_blocks = None + + if attention_head_dim > 0: + if attention_head_dim > in_channels: + raise ValueError("attention_head_dim must be less than or equal to in_channels") + + self.attention_blocks = nn.ModuleList( + [ + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + bias=True, + out_bias=True, + qk_norm="rms_norm", + residual_connection=True, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + timestep_embed = None + if self.timestep_conditioning: + assert timestep is not None, "should pass timestep with timestep_conditioning=True" + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view(batch_size, timestep_embed.shape[-1], 1, 1, 1) + + if self.attention_blocks: + for resnet, attention in zip(self.res_blocks, self.attention_blocks): + hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed) + + # Reshape the hidden states to be (batch_size, frames * height * width, channel) + batch_size, channel, frames, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, 
channel, frames * height * width).transpose(1, 2) + + if attention.use_tpu_flash_attention: + # Pad the second dimension to be divisible by block_k_major (block in flash attention) + seq_len = hidden_states.shape[1] + block_k_major = 512 + pad_len = (block_k_major - seq_len % block_k_major) % block_k_major + if pad_len > 0: + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len), "constant", 0) + + # Create a mask with ones for the original sequence length and zeros for the padded indexes + mask = torch.ones( + (hidden_states.shape[0], seq_len), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if pad_len > 0: + mask = F.pad(mask, (0, pad_len), "constant", 0) + + hidden_states = attention( + hidden_states, + attention_mask=(None if not attention.use_tpu_flash_attention else mask), + ) + + if attention.use_tpu_flash_attention: + # Remove the padding + if pad_len > 0: + hidden_states = hidden_states[:, :-pad_len, :] + + # Reshape the hidden states back to (batch_size, channel, frames, height, width) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, frames, height, width) + else: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed) + + return hidden_states + + +class SpaceToDepthDownsample(nn.Module): + + def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode): + super().__init__() + self.stride = stride + self.group_size = in_channels * np.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // np.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, x, causal: bool = True): + if self.stride[0] == 2: + x = torch.cat([x[:, :, :1, :, :], x], dim=2) # duplicate the first frame for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + +class DepthToSpaceUpsample(nn.Module): + + def __init__( + self, + dims, + in_channels, + stride, + residual=False, + out_channels_reduction_factor=1, + spatial_padding_mode="zeros", + ): + super().__init__() + self.stride = stride + self.out_channels = np.prod(stride) * in_channels // out_channels_reduction_factor + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward(self, x, causal: bool = True): + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = self.pixel_shuffle(x) + num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = self.pixel_shuffle(x) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in +
return x
+
+
+class LayerNorm(nn.Module):
+
+  def __init__(self, dim, eps, elementwise_affine=True) -> None:
+    super().__init__()
+    self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+
+  def forward(self, x):
+    x = rearrange(x, "b c d h w -> b d h w c")
+    x = self.norm(x)
+    x = rearrange(x, "b d h w c -> b c d h w")
+    return x
+
+
+class ResnetBlock3D(nn.Module):
+  r"""
+  A Resnet block.
+
+  Parameters:
+    in_channels (`int`): The number of channels in the input.
+    out_channels (`int`, *optional*, defaults to `None`):
+      The number of output channels for the first conv layer. If None, same as `in_channels`.
+    dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
+    groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
+    eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
+  """
+
+  def __init__(
+      self,
+      dims: Union[int, Tuple[int, int]],
+      in_channels: int,
+      out_channels: Optional[int] = None,
+      dropout: float = 0.0,
+      groups: int = 32,
+      eps: float = 1e-6,
+      norm_layer: str = "group_norm",
+      inject_noise: bool = False,
+      timestep_conditioning: bool = False,
+      spatial_padding_mode: str = "zeros",
+  ):
+    super().__init__()
+    self.in_channels = in_channels
+    out_channels = in_channels if out_channels is None else out_channels
+    self.out_channels = out_channels
+    self.inject_noise = inject_noise
+
+    if norm_layer == "group_norm":
+      self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+    elif norm_layer == "pixel_norm":
+      self.norm1 = PixelNorm()
+    elif norm_layer == "layer_norm":
+      self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
+
+    self.non_linearity = nn.SiLU()
+
+    self.conv1 = make_conv_nd(
+        dims,
+        in_channels,
+        out_channels,
+        kernel_size=3,
+        stride=1,
+        padding=1,
+        causal=True,
+        spatial_padding_mode=spatial_padding_mode,
+    )
+
+    if inject_noise:
+      # noise is injected after conv1, whose output has out_channels channels
+      self.per_channel_scale1 = nn.Parameter(torch.zeros((out_channels, 1, 1)))
+
+    if norm_layer == "group_norm":
+      self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
+    elif norm_layer == "pixel_norm":
+      self.norm2 = PixelNorm()
+    elif norm_layer == "layer_norm":
+      self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
+
+    self.dropout = torch.nn.Dropout(dropout)
+
+    self.conv2 = make_conv_nd(
+        dims,
+        out_channels,
+        out_channels,
+        kernel_size=3,
+        stride=1,
+        padding=1,
+        causal=True,
+        spatial_padding_mode=spatial_padding_mode,
+    )
+
+    if inject_noise:
+      # likewise, conv2 outputs out_channels channels
+      self.per_channel_scale2 = nn.Parameter(torch.zeros((out_channels, 1, 1)))
+
+    self.conv_shortcut = (
+        make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels)
+        if in_channels != out_channels
+        else nn.Identity()
+    )
+
+    self.norm3 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) if in_channels != out_channels else nn.Identity()
+
+    self.timestep_conditioning = timestep_conditioning
+
+    if timestep_conditioning:
+      self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
+
+  def _feed_spatial_noise(self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor) -> torch.FloatTensor:
+    spatial_shape = hidden_states.shape[-2:]
+    device = hidden_states.device
+    dtype = hidden_states.dtype
+
+    # similar to the "explicit noise inputs" method in style-gan
+    spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
+    scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
+    hidden_states = hidden_states + scaled_noise
+
+    return hidden_states
+
+  def forward(
+      self,
+      input_tensor: torch.FloatTensor,
+      causal: bool = True,
+      timestep: Optional[torch.Tensor] = None,
+  ) -> torch.FloatTensor:
+    hidden_states = input_tensor
+    batch_size = hidden_states.shape[0]
+
+    hidden_states = self.norm1(hidden_states).to(torch.bfloat16)
+    if self.timestep_conditioning:
+      assert timestep is not None, "should pass timestep with timestep_conditioning=True"
+      ada_values = self.scale_shift_table[None, ..., None, None, None] + timestep.reshape(
+          batch_size,
+          4,
+          -1,
+          timestep.shape[-3],
+          timestep.shape[-2],
+          timestep.shape[-1],
+      )
+      shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
+
+      hidden_states = hidden_states * (1 + scale1) + shift1
+
+    hidden_states = self.non_linearity(hidden_states)
+
+    hidden_states = self.conv1(hidden_states, causal=causal)
+
+    if self.inject_noise:
+      hidden_states = self._feed_spatial_noise(hidden_states, self.per_channel_scale1)
+
+    hidden_states = self.norm2(hidden_states).to(torch.bfloat16)
+
+    if self.timestep_conditioning:
+      hidden_states = hidden_states * (1 + scale2) + shift2
+
+    hidden_states = self.non_linearity(hidden_states)
+
+    hidden_states = self.dropout(hidden_states)
+    hidden_states = self.conv2(hidden_states, causal=causal)
+
+    if self.inject_noise:
+      hidden_states = self._feed_spatial_noise(hidden_states, self.per_channel_scale2)
+
+    input_tensor = self.norm3(input_tensor).to(torch.bfloat16)
+
+    input_tensor = self.conv_shortcut(input_tensor)
+
+    output_tensor = input_tensor + hidden_states
+
+    return output_tensor
+
+
+def patchify(x, patch_size_hw, patch_size_t=1):
+  if patch_size_hw == 1 and patch_size_t == 1:
+    return x
+  if x.dim() == 4:
+    x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw)
+  elif x.dim() == 5:
+    x = rearrange(
+        x,
+        "b c (f p) (h q) (w r) -> b (c p r q) f h w",
+        p=patch_size_t,
+        q=patch_size_hw,
+        r=patch_size_hw,
+    )
+  else:
+    raise ValueError(f"Invalid input shape: {x.shape}")
+
+  return x
+
+
+def unpatchify(x, patch_size_hw, patch_size_t=1):
+  if patch_size_hw == 1 and patch_size_t == 1:
+    return x
+
+  if x.dim() == 4:
+    x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw)
+  elif x.dim() == 5:
+    x = rearrange(
+        x,
+        "b (c p r q) f h w -> b c (f p) (h q) (w r)",
+        p=patch_size_t,
+        q=patch_size_hw,
+        r=patch_size_hw,
+    )
+  else:
+    raise ValueError(f"Invalid input shape: {x.shape}")
+
+  return x
+
+
+def create_video_autoencoder_demo_config(
+    latent_channels: int = 64,
+):
+  encoder_blocks = [
+      ("res_x", {"num_layers": 2}),
+      ("compress_space_res", {"multiplier": 2}),
+      ("res_x", {"num_layers": 2}),
+      ("compress_time_res", {"multiplier": 2}),
+      ("res_x", {"num_layers": 1}),
+      ("compress_all_res", {"multiplier": 2}),
+      ("res_x", {"num_layers": 1}),
+      ("compress_all_res", {"multiplier": 2}),
+      ("res_x", {"num_layers": 1}),
+  ]
+  decoder_blocks = [
+      ("res_x", {"num_layers": 2, "inject_noise": False}),
+      ("compress_all", {"residual": True, "multiplier": 2}),
+      ("res_x", {"num_layers": 2, "inject_noise": False}),
+      ("compress_all", {"residual": True, "multiplier": 2}),
+      ("res_x", {"num_layers": 2, "inject_noise": False}),
+      ("compress_all", {"residual": True, "multiplier": 2}),
+      ("res_x", {"num_layers": 2, "inject_noise": False}),
+  ]
+  return {
+      "_class_name": "CausalVideoAutoencoder",
+      "dims": 3,
+      "encoder_blocks": encoder_blocks,
+      "decoder_blocks": decoder_blocks,
+      "latent_channels": latent_channels,
+      "norm_layer": "pixel_norm",
+      "patch_size": 4,
+      "latent_log_var": "uniform",
+      "use_quant_conv": False,
+      "causal_decoder": False,
+      "timestep_conditioning": True,
+      "spatial_padding_mode": "replicate",
+  }
+
+
+def test_vae_patchify_unpatchify():
+  import torch
+
+  x = torch.randn(2, 3, 8, 64, 64)
+  x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
+  x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
+  assert torch.allclose(x, x_unpatched)
+
+
+def demo_video_autoencoder_forward_backward():
+  # Configuration for the VideoAutoencoder
+  config = create_video_autoencoder_demo_config()
+
+  # Instantiate the VideoAutoencoder with the specified configuration
+  video_autoencoder = CausalVideoAutoencoder.from_config(config)
+
+  print(video_autoencoder)
+  video_autoencoder.eval()
+  # Print the total number of parameters in the video autoencoder
+  total_params = sum(p.numel() for p in video_autoencoder.parameters())
+  print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")
+
+  # Create a mock input tensor simulating a batch of videos
+  # Shape: (batch_size, channels, depth, height, width)
+  # E.g., 2 videos, each with 3 color channels, 17 frames, and 64x64 pixels per frame
+  input_videos = torch.randn(2, 3, 17, 64, 64)
+
+  # Forward pass: encode and decode the input videos
+  latent = video_autoencoder.encode(input_videos).latent_dist.mode()
+  print(f"input shape={input_videos.shape}")
+  print(f"latent shape={latent.shape}")
+
+  timestep = torch.ones(input_videos.shape[0]) * 0.1
+  reconstructed_videos = video_autoencoder.decode(latent, target_shape=input_videos.shape, timestep=timestep).sample
+
+  print(f"reconstructed shape={reconstructed_videos.shape}")
+
+  # Validate that a single image gets treated the same way as the first frame
+  input_image = input_videos[:, :, :1, :, :]
+  image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
+  # target_shape is the desired pixel-space shape, so decode back to the image shape
+  reconstructed_image = video_autoencoder.decode(image_latent, target_shape=input_image.shape, timestep=timestep).sample
+
+  first_frame_latent = latent[:, :, :1, :, :]
+
+  assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
+  # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6)
+  # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all()
+
+  # Calculate the loss (e.g., mean squared error)
+  loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
+
+  # Perform backward pass
+  loss.backward()
+
+  print(f"Demo completed with loss: {loss.item()}")
+
+
+# Run the demo to execute the forward and backward pass
+if __name__ == "__main__":
+  demo_video_autoencoder_forward_backward()
diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py
new file mode 100644
index 000000000..d0be897e8
--- /dev/null
+++ b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py
@@ -0,0 +1,102 @@
+# Copyright 2025 Lightricks Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import Tuple, Union + +import torch + +from maxdiffusion.models.ltx_video.autoencoders.dual_conv3d import DualConv3d +from maxdiffusion.models.ltx_video.autoencoders.causal_conv3d import CausalConv3d + + +def make_conv_nd( + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: int, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + causal=False, + spatial_padding_mode="zeros", + temporal_padding_mode="zeros", +): + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") + if dims == 2: + return torch.nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + spatial_padding_mode=spatial_padding_mode, + ) + return torch.nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: int, + bias=True, +): + if dims == 2: + return torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + elif dims == 3 or dims == (2, 1): + return torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + else: + raise ValueError(f"unsupported dimensions: {dims}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py new file mode 100644 index 000000000..c0a4db3eb --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py @@ -0,0 +1,224 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
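+#
+# DualConv3d factorizes a full 3D convolution into a spatial (1, kH, kW)
+# convolution followed by a temporal (kT, 1, 1) convolution, in the spirit of
+# (2+1)D decompositions of video convolutions. forward_with_3d and
+# forward_with_2d evaluate the same factorized kernel in two different ways;
+# test_dual_conv3d_consistency below checks that they agree. An illustrative
+# sketch (shapes are hypothetical, not taken from the pipeline):
+#
+#   conv = DualConv3d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
+#   y = conv(torch.randn(1, 3, 16, 32, 32))  # (B, C, F, H, W) -> (1, 8, 16, 32, 32)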
+#
+# This implementation is based on the Torch version available at:
+# https://github.com/Lightricks/LTX-Video/tree/main
+import math
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+
+class DualConv3d(nn.Module):
+
+  def __init__(
+      self,
+      in_channels,
+      out_channels,
+      kernel_size,
+      stride: Union[int, Tuple[int, int, int]] = 1,
+      padding: Union[int, Tuple[int, int, int]] = 0,
+      dilation: Union[int, Tuple[int, int, int]] = 1,
+      groups=1,
+      bias=True,
+      padding_mode="zeros",
+  ):
+    super(DualConv3d, self).__init__()
+
+    self.in_channels = in_channels
+    self.out_channels = out_channels
+    self.padding_mode = padding_mode
+    # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
+    if isinstance(kernel_size, int):
+      kernel_size = (kernel_size, kernel_size, kernel_size)
+    if kernel_size == (1, 1, 1):
+      raise ValueError("kernel_size must be greater than 1. Use make_linear_nd instead.")
+    if isinstance(stride, int):
+      stride = (stride, stride, stride)
+    if isinstance(padding, int):
+      padding = (padding, padding, padding)
+    if isinstance(dilation, int):
+      dilation = (dilation, dilation, dilation)
+
+    # Set parameters for convolutions
+    self.groups = groups
+    self.bias = bias
+
+    # Define the size of the channels after the first convolution
+    intermediate_channels = out_channels if in_channels < out_channels else in_channels
+
+    # Define parameters for the first convolution
+    self.weight1 = nn.Parameter(
+        torch.Tensor(
+            intermediate_channels,
+            in_channels // groups,
+            1,
+            kernel_size[1],
+            kernel_size[2],
+        )
+    )
+    self.stride1 = (1, stride[1], stride[2])
+    self.padding1 = (0, padding[1], padding[2])
+    self.dilation1 = (1, dilation[1], dilation[2])
+    if bias:
+      self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
+    else:
+      self.register_parameter("bias1", None)
+
+    # Define parameters for the second convolution
+    self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1))
+    self.stride2 = (stride[0], 1, 1)
+    self.padding2 = (padding[0], 0, 0)
+    self.dilation2 = (dilation[0], 1, 1)
+    if bias:
+      self.bias2 = nn.Parameter(torch.Tensor(out_channels))
+    else:
+      self.register_parameter("bias2", None)
+
+    # Initialize weights and biases
+    self.reset_parameters()
+
+  def reset_parameters(self):
+    nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
+    nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
+    if self.bias:
+      fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
+      bound1 = 1 / math.sqrt(fan_in1)
+      nn.init.uniform_(self.bias1, -bound1, bound1)
+      fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
+      bound2 = 1 / math.sqrt(fan_in2)
+      nn.init.uniform_(self.bias2, -bound2, bound2)
+
+  def forward(self, x, use_conv3d=False, skip_time_conv=False):
+    if use_conv3d:
+      return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
+    else:
+      return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
+
+  def forward_with_3d(self, x, skip_time_conv):
+    # First convolution. F.conv3d only supports implicit zero padding, so the
+    # configured padding_mode is not forwarded to the functional call.
+    x = F.conv3d(
+        x,
+        self.weight1,
+        self.bias1,
+        self.stride1,
+        self.padding1,
+        self.dilation1,
+        self.groups,
+    )
+
+    if skip_time_conv:
+      return x
+
+    # Second convolution
+    x = F.conv3d(
+        x,
+        self.weight2,
+        self.bias2,
+        self.stride2,
+        self.padding2,
+        self.dilation2,
+        self.groups,
+    )
+
+    return x
+
+  def forward_with_2d(self, x, skip_time_conv):
+    b, c, d, h, w = x.shape
+
+    # First 2D convolution
+    x = rearrange(x, "b c d h w -> (b d) c h w")
+    # Squeeze the depth dimension out of weight1 since it's 1
+    weight1 = self.weight1.squeeze(2)
+    # Select stride, padding, and dilation for the 2D convolution
+    stride1 = (self.stride1[1], self.stride1[2])
+    padding1 = (self.padding1[1], self.padding1[2])
+    dilation1 = (self.dilation1[1], self.dilation1[2])
+    # As above, the functional API only supports implicit zero padding.
+    x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
+
+    _, _, h, w = x.shape
+
+    if skip_time_conv:
+      x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
+      return x
+
+    # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
+    x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
+
+    # Reshape weight2 to match the expected dimensions for conv1d
+    weight2 = self.weight2.squeeze(-1).squeeze(-1)
+    # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
+    stride2 = self.stride2[0]
+    padding2 = self.padding2[0]
+    dilation2 = self.dilation2[0]
+    x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
+    x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
+
+    return x
+
+  @property
+  def weight(self):
+    return self.weight2
+
+
+def test_dual_conv3d_consistency():
+  # Initialize parameters
+  in_channels = 3
+  out_channels = 5
+  kernel_size = (3, 3, 3)
+  stride = (2, 2, 2)
+  padding = (1, 1, 1)
+
+  # Create an instance of the DualConv3d class
+  dual_conv3d = DualConv3d(
+      in_channels=in_channels,
+      out_channels=out_channels,
+      kernel_size=kernel_size,
+      stride=stride,
+      padding=padding,
+      bias=True,
+  )
+
+  # Example input tensor
+  test_input = torch.randn(1, 3, 10, 10, 10)
+
+  # Perform forward passes with both 3D and 2D settings
+  output_conv3d = dual_conv3d(test_input, use_conv3d=True)
+  output_2d = dual_conv3d(test_input, use_conv3d=False)
+
+  # Assert that the outputs from both methods are sufficiently close
+  assert torch.allclose(output_conv3d, output_2d, atol=1e-6), "Outputs are not consistent between 3D and 2D convolutions."
diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py
new file mode 100644
index 000000000..56a6c2d1b
--- /dev/null
+++ b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py
@@ -0,0 +1,210 @@
+# Copyright 2025 Lightricks Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
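+#
+# The LatentUpsampler refines VAE latents with ResBlocks and trades channels
+# for resolution through a PixelShuffleND stage (spatial, temporal, or both).
+# A minimal sketch of the default spatial 2x path, mirroring the __main__ demo
+# at the bottom of this file (shapes are illustrative):
+#
+#   upsampler = LatentUpsampler(in_channels=128, mid_channels=512, dims=3)
+#   up = upsampler(torch.randn(1, 128, 9, 16, 16))  # -> (1, 128, 9, 32, 32)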
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import Optional, Union +from pathlib import Path +import os +import json + +import torch +import torch.nn as nn +from einops import rearrange +from diffusers import ConfigMixin, ModelMixin +from safetensors.torch import safe_open + +from maxdiffusion.models.ltx_video.autoencoders.pixel_shuffle import PixelShuffleND + + +class ResBlock(nn.Module): + + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(32, channels) + self.activation = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class LatentUpsampler(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = nn.GroupNorm(32, mid_channels) + self.initial_activation = nn.SiLU() + + self.res_blocks = nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + self.upsampler = nn.Sequential( + nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for 
block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + x = x[:, :, 1:, :, :] + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + @classmethod + def from_config(cls, config): + return cls( + in_channels=config.get("in_channels", 4), + mid_channels=config.get("mid_channels", 128), + num_blocks_per_stage=config.get("num_blocks_per_stage", 4), + dims=config.get("dims", 2), + spatial_upsample=config.get("spatial_upsample", True), + temporal_upsample=config.get("temporal_upsample", False), + ) + + def config(self): + return { + "_class_name": "LatentUpsampler", + "in_channels": self.in_channels, + "mid_channels": self.mid_channels, + "num_blocks_per_stage": self.num_blocks_per_stage, + "dims": self.dims, + "spatial_upsample": self.spatial_upsample, + "temporal_upsample": self.temporal_upsample, + } + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_file() and str(pretrained_model_path).endswith(".safetensors"): + state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + config = json.loads(metadata["config"]) + with torch.device("meta"): + latent_upsampler = LatentUpsampler.from_config(config) + latent_upsampler.load_state_dict(state_dict, assign=True) + return latent_upsampler + + +if __name__ == "__main__": + latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3) + print(latent_upsampler) + total_params = sum(p.numel() for p in latent_upsampler.parameters()) + print(f"Total number of parameters: {total_params:,}") + latent = torch.randn(1, 128, 9, 16, 16) + upsampled_latent = latent_upsampler(latent) + print(f"Upsampled latent shape: {upsampled_latent.shape}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py new file mode 100644 index 000000000..422df50f5 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py @@ -0,0 +1,29 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
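+#
+# PixelNorm normalizes each spatial position across its channel vector,
+#
+#   y = x / sqrt(mean(x**2, dim=channel) + eps)
+#
+# i.e. the feature-vector normalization popularized by Progressive GAN. Unlike
+# GroupNorm or LayerNorm it has no learnable affine parameters.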
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +import torch +from torch import nn + + +class PixelNorm(nn.Module): + + def __init__(self, dim=1, eps=1e-8): + super(PixelNorm, self).__init__() + self.dim = dim + self.eps = eps + + def forward(self, x): + return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py new file mode 100644 index 000000000..7bd4761c4 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py @@ -0,0 +1,50 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +import torch.nn as nn +from einops import rearrange + + +class PixelShuffleND(nn.Module): + + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x): + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py new file mode 100644 index 000000000..17823c5d0 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py @@ -0,0 +1,360 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
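+#
+# AutoencoderKLWrapper pairs an arbitrary encoder/decoder with a diagonal
+# Gaussian posterior and adds two optional tiling modes: z-tiling splits the
+# latent frame axis into chunks of z_sample_size, while hw-tiling encodes and
+# decodes overlapping spatial tiles and linearly blends the seams (see
+# blend_v / blend_h below).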
+#
+# This implementation is based on the Torch version available at:
+# https://github.com/Lightricks/LTX-Video/tree/main
+from typing import Optional, Union
+
+import torch
+import inspect
+import math
+import torch.nn as nn
+from diffusers import ConfigMixin, ModelMixin
+from diffusers.models.autoencoders.vae import (
+    DecoderOutput,
+    DiagonalGaussianDistribution,
+)
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd
+
+
+class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
+  """Variational Autoencoder (VAE) model with KL loss.
+
+  VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
+  This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss.
+
+  Args:
+    encoder (`nn.Module`):
+      Encoder module.
+    decoder (`nn.Module`):
+      Decoder module.
+    latent_channels (`int`, *optional*, defaults to 4):
+      Number of latent channels.
+  """
+
+  def __init__(
+      self,
+      encoder: nn.Module,
+      decoder: nn.Module,
+      latent_channels: int = 4,
+      dims: int = 2,
+      sample_size=512,
+      use_quant_conv: bool = True,
+      normalize_latent_channels: bool = False,
+  ):
+    super().__init__()
+
+    # pass init params to Encoder
+    self.encoder = encoder
+    self.use_quant_conv = use_quant_conv
+    self.normalize_latent_channels = normalize_latent_channels
+
+    # pass init params to Decoder
+    quant_dims = 2 if dims == 2 else 3
+    self.decoder = decoder
+    if use_quant_conv:
+      self.quant_conv = make_conv_nd(quant_dims, 2 * latent_channels, 2 * latent_channels, 1)
+      self.post_quant_conv = make_conv_nd(quant_dims, latent_channels, latent_channels, 1)
+    else:
+      self.quant_conv = nn.Identity()
+      self.post_quant_conv = nn.Identity()
+
+    if normalize_latent_channels:
+      if dims == 2:
+        self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False)
+      else:
+        self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False)
+    else:
+      self.latent_norm_out = nn.Identity()
+    self.use_z_tiling = False
+    self.use_hw_tiling = False
+    self.dims = dims
+    self.z_sample_size = 1
+
+    self.decoder_params = inspect.signature(self.decoder.forward).parameters
+
+    # only relevant if vae tiling is enabled
+    self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)
+
+  def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25):
+    self.tile_sample_min_size = sample_size
+    num_blocks = len(self.encoder.down_blocks)
+    self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))
+    self.tile_overlap_factor = overlap_factor
+
+  def enable_z_tiling(self, z_sample_size: int = 8):
+    r"""
+    Enable tiling along the temporal (z) axis during VAE encoding and decoding.
+
+    When this option is enabled, the VAE will split the input tensor in tiles to compute encoding and
+    decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+    """
+    assert z_sample_size % 8 == 0 or z_sample_size == 1, f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}."
+    self.use_z_tiling = z_sample_size > 1
+    self.z_sample_size = z_sample_size
+
+  def disable_z_tiling(self):
+    r"""
+    Disable z tiling. If `enable_z_tiling` was previously invoked, this method will go back to computing
+    encoding and decoding in one step.
+    """
+    self.use_z_tiling = False
+
+  def enable_hw_tiling(self):
+    r"""
+    Enable tiling during VAE decoding along the height and width dimension.
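+
+    An illustrative call pattern (assuming `vae` is an AutoencoderKLWrapper):
+
+      >>> vae.set_tiling_params(sample_size=512, overlap_factor=0.25)
+      >>> vae.enable_hw_tiling()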
+ """ + self.use_hw_tiling = True + + def disable_hw_tiling(self): + r""" + Disable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = False + + def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + moments = torch.cat(result_rows, dim=3) + return moments + + def blend_z(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for z in range(blend_extent): + b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (1 - z / blend_extent) + b[:, :, z, :, :] * (z / blend_extent) + return b + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape): + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + tile_target_shape = ( + *target_shape[:3], + self.tile_sample_min_size, + self.tile_sample_min_size, + ) + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+    rows = []
+    for i in range(0, z.shape[3], overlap_size):
+      row = []
+      for j in range(0, z.shape[4], overlap_size):
+        tile = z[
+            :,
+            :,
+            :,
+            i : i + self.tile_latent_min_size,
+            j : j + self.tile_latent_min_size,
+        ]
+        tile = self.post_quant_conv(tile)
+        decoded = self.decoder(tile, target_shape=tile_target_shape)
+        row.append(decoded)
+      rows.append(row)
+    result_rows = []
+    for i, row in enumerate(rows):
+      result_row = []
+      for j, tile in enumerate(row):
+        # blend the above tile and the left tile
+        # to the current tile and add the current tile to the result row
+        if i > 0:
+          tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+        if j > 0:
+          tile = self.blend_h(row[j - 1], tile, blend_extent)
+        result_row.append(tile[:, :, :, :row_limit, :row_limit])
+      result_rows.append(torch.cat(result_row, dim=4))
+
+    dec = torch.cat(result_rows, dim=3)
+    return dec
+
+  def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, tuple]:
+    if self.use_z_tiling and x.shape[2] > self.z_sample_size > 1:
+      num_splits = x.shape[2] // self.z_sample_size
+      sizes = [self.z_sample_size] * num_splits
+      sizes = sizes + [x.shape[2] - sum(sizes)] if x.shape[2] - sum(sizes) > 0 else sizes
+      tiles = x.split(sizes, dim=2)
+      moments_tiles = [
+          (self._hw_tiled_encode(x_tile, return_dict) if self.use_hw_tiling else self._encode(x_tile)) for x_tile in tiles
+      ]
+      moments = torch.cat(moments_tiles, dim=2)
+    else:
+      moments = self._hw_tiled_encode(x, return_dict) if self.use_hw_tiling else self._encode(x)
+    posterior = DiagonalGaussianDistribution(moments)
+    if not return_dict:
+      return (posterior,)
+
+    return AutoencoderKLOutput(latent_dist=posterior)
+
+  def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
+    if isinstance(self.latent_norm_out, nn.BatchNorm3d):
+      _, c, _, _, _ = z.shape
+      z = torch.cat(
+          [
+              self.latent_norm_out(z[:, : c // 2, :, :, :]),
+              z[:, c // 2 :, :, :, :],
+          ],
+          dim=1,
+      )
+    elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
+      raise NotImplementedError("BatchNorm2d not supported")
+    return z
+
+  def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
+    if isinstance(self.latent_norm_out, nn.BatchNorm3d):
+      running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1)
+      running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1)
+      eps = self.latent_norm_out.eps
+
+      z = z * torch.sqrt(running_var + eps) + running_mean
+    elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
+      raise NotImplementedError("BatchNorm2d not supported")
+    return z
+
+  def _encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    h = self.encoder(x)
+    moments = self.quant_conv(h)
+    moments = self._normalize_latent_channels(moments)
+    return moments
+
+  def _decode(
+      self,
+      z: torch.FloatTensor,
+      target_shape=None,
+      timestep: Optional[torch.Tensor] = None,
+  ) -> Union[DecoderOutput, torch.FloatTensor]:
+    z = self._unnormalize_latent_channels(z)
+    z = self.post_quant_conv(z)
+    if "timestep" in self.decoder_params:
+      dec = self.decoder(z, target_shape=target_shape, timestep=timestep)
+    else:
+      dec = self.decoder(z, target_shape=target_shape)
+    return dec
+
+  def decode(
+      self,
+      z: torch.FloatTensor,
+      return_dict: bool = True,
+      target_shape=None,
+      timestep: Optional[torch.Tensor] = None,
+  ) -> Union[DecoderOutput, torch.FloatTensor]:
+    assert target_shape is not None, "target_shape must be provided for decoding"
+    if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
+      reduction_factor = int(
+          self.encoder.patch_size_t * 2 ** (len(self.encoder.down_blocks) - 1 - math.sqrt(self.encoder.patch_size))
+      )
+      split_size = self.z_sample_size // reduction_factor
+      num_splits = z.shape[2] // split_size
+
+      # copy target shape, and divide frame dimension (=2) by the context size
+      target_shape_split = list(target_shape)
+      target_shape_split[2] = target_shape[2] // num_splits
+
+      decoded_tiles = [
+          (
+              self._hw_tiled_decode(z_tile, target_shape_split)
+              if self.use_hw_tiling
+              else self._decode(z_tile, target_shape=target_shape_split)
+          )
+          for z_tile in torch.tensor_split(z, num_splits, dim=2)
+      ]
+      decoded = torch.cat(decoded_tiles, dim=2)
+    else:
+      decoded = (
+          self._hw_tiled_decode(z, target_shape)
+          if self.use_hw_tiling
+          else self._decode(z, target_shape=target_shape, timestep=timestep)
+      )
+
+    if not return_dict:
+      return (decoded,)
+
+    return DecoderOutput(sample=decoded)
+
+  def forward(
+      self,
+      sample: torch.FloatTensor,
+      sample_posterior: bool = False,
+      return_dict: bool = True,
+      generator: Optional[torch.Generator] = None,
+  ) -> Union[DecoderOutput, torch.FloatTensor]:
+    r"""
+    Args:
+      sample (`torch.FloatTensor`): Input sample.
+      sample_posterior (`bool`, *optional*, defaults to `False`):
+        Whether to sample from the posterior.
+      return_dict (`bool`, *optional*, defaults to `True`):
+        Whether to return a [`DecoderOutput`] instead of a plain tuple.
+      generator (`torch.Generator`, *optional*):
+        Generator used to sample from the posterior.
+    """
+    x = sample
+    posterior = self.encode(x).latent_dist
+    if sample_posterior:
+      z = posterior.sample(generator=generator)
+    else:
+      z = posterior.mode()
+    dec = self.decode(z, target_shape=sample.shape).sample
+
+    if not return_dict:
+      return (dec,)
+
+    return DecoderOutput(sample=dec)
diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py
new file mode 100644
index 000000000..d4c44024e
--- /dev/null
+++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py
@@ -0,0 +1,225 @@
+# Copyright 2025 Lightricks Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
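+#
+# Helpers for moving between pixel space and the VAE latent space. Latents are
+# either scaled by the scalar vae.config.scaling_factor or, when
+# vae_per_channel_normalize is set, standardized per channel with the
+# std_of_means / mean_of_means statistics registered on the autoencoder (see
+# normalize_latents / un_normalize_latents at the bottom of this file).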
+#
+# This implementation is based on the Torch version available at:
+# https://github.com/Lightricks/LTX-Video/tree/main
+from typing import Tuple
+import torch
+from diffusers import AutoencoderKL
+from einops import rearrange
+from torch import Tensor
+
+from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import (
+    CausalVideoAutoencoder,
+)
+from maxdiffusion.models.ltx_video.autoencoders.video_autoencoder import (
+    Downsample3D,
+    VideoAutoencoder,
+)
+
+try:
+  import torch_xla.core.xla_model as xm
+except ImportError:
+  xm = None
+
+
+def vae_encode(
+    media_items: Tensor,
+    vae: AutoencoderKL,
+    split_size: int = 1,
+    vae_per_channel_normalize=False,
+) -> Tensor:
+  """
+  Encodes media items (images or videos) into latent representations using a specified VAE model.
+  The function supports processing batches of images or video frames and can handle the processing
+  in smaller sub-batches if needed.
+
+  Args:
+    media_items (Tensor): A torch Tensor containing the media items to encode. The expected
+      shape is (batch_size, channels, height, width) for images or (batch_size, channels,
+      frames, height, width) for videos.
+    vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
+      pre-configured and loaded with the appropriate model weights.
+    split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
+      If set to more than 1, the input media items are processed in smaller batches according to
+      this value. Defaults to 1, which processes all items in a single batch.
+
+  Returns:
+    Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted
+      to match the input shape, scaled by the model's configuration.
+
+  Examples:
+    >>> import torch
+    >>> from diffusers import AutoencoderKL
+    >>> vae = AutoencoderKL.from_pretrained('your-model-name')
+    >>> images = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames.
+    >>> latents = vae_encode(images, vae)
+    >>> print(latents.shape)  # Output shape will depend on the model's latent configuration.
+
+  Note:
+    In case of a video, the function encodes the media item frame-by-frame.
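+
+    `vae_per_channel_normalize` selects per-channel standardization with the
+    VAE's std_of_means / mean_of_means statistics instead of the scalar
+    `vae.config.scaling_factor` (see `normalize_latents` below).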
+ """ + is_video_shaped = media_items.dim() == 5 + batch_size, channels = media_items.shape[0:2] + + if channels != 3: + raise ValueError(f"Expects tensors with 3 channels, got {channels}.") + + if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + media_items = rearrange(media_items, "b c n h w -> (b n) c h w") + if split_size > 1: + if len(media_items) % split_size != 0: + raise ValueError("Error: The batch size must be divisible by 'train.vae_bs_split") + encode_bs = len(media_items) // split_size + # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)] + latents = [] + if media_items.device.type == "xla": + xm.mark_step() + for image_batch in media_items.split(encode_bs): + latents.append(vae.encode(image_batch).latent_dist.sample()) + if media_items.device.type == "xla": + xm.mark_step() + latents = torch.cat(latents, dim=0) + else: + dist = vae.encode(media_items).latent_dist + latents = dist.sample() + + latents = normalize_latents(latents, vae, vae_per_channel_normalize) + if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size) + return latents + + +def vae_decode( # this function needs latents to be in tensor form + latents: Tensor, + vae: AutoencoderKL, + is_video: bool = True, + split_size: int = 1, + vae_per_channel_normalize=False, + timestep=None, +) -> Tensor: + is_video_shaped = latents.dim() == 5 + batch_size = latents.shape[0] + + if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + latents = rearrange(latents, "b c n h w -> (b n) c h w") + if split_size > 1: + if len(latents) % split_size != 0: + raise ValueError("Error: The batch size must be divisible by 'train.vae_bs_split") + encode_bs = len(latents) // split_size + image_batch = [ + _run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize, timestep) + for latent_batch in latents.split(encode_bs) + ] + images = torch.cat(image_batch, dim=0) + else: + images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize, timestep) + + if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size) + return images + + +def _run_decoder( + latents: Tensor, + vae: AutoencoderKL, + is_video: bool, + vae_per_channel_normalize=False, + timestep=None, +) -> Tensor: + if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + *_, fl, hl, wl = latents.shape + temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae) + latents = latents.to(vae.dtype) + vae_decode_kwargs = {} + if timestep is not None: + vae_decode_kwargs["timestep"] = timestep + + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + return_dict=False, + target_shape=( + 1, + 3, + fl * temporal_scale if is_video else 1, + hl * spatial_scale, + wl * spatial_scale, + ), + **vae_decode_kwargs, + )[0] + else: + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + return_dict=False, + )[0] + return image + + +def get_vae_size_scale_factor(vae: AutoencoderKL) -> float: + if isinstance(vae, CausalVideoAutoencoder): + spatial = vae.spatial_downscale_factor + temporal = vae.temporal_downscale_factor + else: + down_blocks = len([block for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D)]) + spatial = vae.config.patch_size * 
2**down_blocks + temporal = vae.config.patch_size_t * 2 ** down_blocks if isinstance(vae, VideoAutoencoder) else 1 + + return (temporal, spatial, spatial) + + +def latent_to_pixel_coords(latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False) -> Tensor: + """ + Converts latent coordinates to pixel coordinates by scaling them according to the VAE's + configuration. + + Args: + latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] + containing the latent corner coordinates of each token. + vae (AutoencoderKL): The VAE model + causal_fix (bool): Whether to take into account the different temporal scale + of the first frame. Default = False for backwards compatibility. + Returns: + Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. + """ + + scale_factors = get_vae_size_scale_factor(vae) + causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix + pixel_coords = latent_to_pixel_coords_from_factors(latent_coords, scale_factors, causal_fix) + return pixel_coords + + +def latent_to_pixel_coords_from_factors(latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False) -> Tensor: + pixel_coords = latent_coords * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + if causal_fix: + # Fix temporal scale for first frame to 1 due to causality + pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + return pixel_coords + + +def normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor: + return ( + (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)) + / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents * vae.config.scaling_factor + ) + + +def un_normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor: + return ( + latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents / vae.config.scaling_factor + ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py new file mode 100644 index 000000000..ea46b4fbd --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py @@ -0,0 +1,111 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
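+#
+# TorchaxCausalVideoAutoencoder bridges the Torch VAE into JAX via torchax:
+# inputs are viewed as torch tensors with interop.torch_view, encode/decode run
+# through JittableModule.functional_call over (params, buffers), and outputs
+# are returned to JAX with interop.jax_view; split_size and
+# vae_per_channel_normalize are marked static for jit.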
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder +from maxdiffusion.models.ltx_video.autoencoders import causal_conv3d +from maxdiffusion.models.ltx_video.autoencoders.vae_encode import vae_encode, vae_decode + +import jax +from torchax import interop +from torchax import default_env + +# remove weight attribute to avoid error in JittableModule +# in the future, this will be fixed in ltxv public repo +delattr(causal_conv3d.CausalConv3d, "weight") + + +class TorchaxCausalVideoAutoencoder(interop.JittableModule): + + def __init__(self, vae: CausalVideoAutoencoder): + super().__init__(vae, extra_jit_args=dict(static_argnames=["split_size", "vae_per_channel_normalize"])) # noqa: C408 + + def encode(self, media_items: jax.Array, split_size: int = 1, vae_per_channel_normalize: bool = True) -> jax.Array: + if media_items.ndim != 5: + raise ValueError( + f"Expected media_items to have 5 dimensions (batch, channels, frames, height, width), but got {media_items.ndim} dimensions." + ) + num_frames = media_items.shape[2] + if (num_frames - 1) % 8 != 0: + raise ValueError( + f"Expected media_items to have a number of frames that is 1 + 8 * k for some integer k, but got {num_frames} frames." + ) + with default_env(): + media_items = interop.torch_view(media_items) + + output = self.functional_call( + self._vae_encoder_inner, + params=self.params, + buffers=self.buffers, + media_items=media_items, + split_size=split_size, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + + return interop.jax_view(output) + + def decode( + self, + latents: jax.Array, + timestep: jax.Array, + split_size: int = 1, + vae_per_channel_normalize: bool = True, + is_video: bool = True, + ) -> jax.Array: + with default_env(): + latents = interop.torch_view(latents) + timestep = interop.torch_view(timestep) + output = self.functional_call( + self._vae_decoder_inner, + params=self.params, + buffers=self.buffers, + latents=latents, + timestep=timestep, + split_size=split_size, + vae_per_channel_normalize=vae_per_channel_normalize, + is_video=is_video, + ) + + return interop.jax_view(output) + + @staticmethod + def _vae_encoder_inner(model, media_items, split_size, vae_per_channel_normalize): + return vae_encode( + media_items=media_items, + vae=model, + split_size=split_size, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + + @staticmethod + def _vae_decoder_inner( + model, latents, timestep, is_video: bool = True, split_size: int = 1, vae_per_channel_normalize: bool = False + ): + return vae_decode( + latents=latents, + vae=model, + is_video=is_video, + split_size=split_size, + vae_per_channel_normalize=vae_per_channel_normalize, + timestep=timestep, + ) + + @staticmethod + def normalize_img(image): + return (image - 128) / 128 + + @staticmethod + def denormalize_img(image): + return (image * 128 + 128).clip(0, 255) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py new file mode 100644 index 000000000..8e6b67a43 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py @@ -0,0 +1,980 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +import json +import os +from functools import partial +from types import SimpleNamespace +from typing import Any, Mapping, Optional, Tuple, Union + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional + +from diffusers.utils import logging + +from ltx_video.utils.torch_utils import Identity +from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from maxdiffusion.models.ltx_video.autoencoders.pixel_norm import PixelNorm +from maxdiffusion.models.ltx_video.autoencoders.vae import AutoencoderKLWrapper + +logger = logging.get_logger(__name__) + + +class VideoAutoencoder(AutoencoderKLWrapper): + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + video_vae = cls.from_config(config) + video_vae.to(kwargs["torch_dtype"]) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + ckpt_state_dict = torch.load(model_local_path) + video_vae.load_state_dict(ckpt_state_dict) + + statistics_local_path = pretrained_model_name_or_path / "per_channel_statistics.json" + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = {col: torch.tensor(vals) for col, vals in zip(data["columns"], transposed_data)} + video_vae.register_buffer("std_of_means", data_dict["std-of-means"]) + video_vae.register_buffer( + "mean_of_means", + data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"])), + ) + + return video_vae + + @staticmethod + def from_config(config): + assert config["_class_name"] == "VideoAutoencoder", "config must have _class_name=VideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get("latent_log_var", "per_channel" if double_z else "none") + use_quant_conv = config.get("use_quant_conv", True) + + if use_quant_conv and latent_log_var == "uniform": + raise ValueError("uniform latent_log_var requires use_quant_conv=False") + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + 
block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + dims = config["dims"] + return VideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="VideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels // (self.encoder.patch_size_t * self.encoder.patch_size**2), + out_channels=self.decoder.conv_out.out_channels // (self.decoder.patch_size_t * self.decoder.patch_size**2), + latent_channels=self.decoder.conv_in.in_channels, + block_out_channels=[ + self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels for i in range(len(self.encoder.down_blocks)) + ], + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + patch_size=self.encoder.patch_size, + latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + patch_size_t=self.encoder.patch_size_t, + add_channel_padding=self.encoder.add_channel_padding, + ) + + @property + def is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 + + @property + def downscale_factor(self): + return self.encoder.downsample_factor + + def to_json_string(self) -> str: + import json + + return json.dumps(self.config.__dict__) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + model_keys = set(name for name, _ in self.named_parameters()) # noqa: C401 + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + + converted_state_dict = {} + for key, value in state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + if "norm" in key and key not in model_keys: + logger.info(f"Removing key {key} from state_dict as it is not present in the model") + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. 
+ latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + if add_channel_padding: + in_channels = in_channels * self.patch_size**3 + else: + in_channels = in_channels * self.patch_size_t * self.patch_size**2 + self.in_channels = in_channels + output_channel = block_out_channels[0] + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + num_layers=self.layers_per_block, + add_downsample=not is_final_block and 2**i >= patch_size, + resnet_eps=1e-6, + downsample_padding=0, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.down_blocks.append(down_block) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], + num_groups=norm_num_groups, + eps=1e-6, + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd(dims, block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + @property + def downscale_factor(self): + return 2 ** len([block for block in self.down_blocks if isinstance(block.downsample, Downsample3D)]) * self.patch_size + + def forward(self, sample: torch.FloatTensor, return_features=False) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + downsample_in_time = sample.shape[2] != 1 + + # patchify + patch_size_t = self.patch_size_t if downsample_in_time else 1 + sample = patchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + if return_features: + 
features = [] + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)(sample, downsample_in_time=downsample_in_time) + if return_features: + features.append(sample) + + sample = checkpoint_fn(self.mid_block)(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1, 1) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + + if return_features: + features.append(sample[:, : self.latent_channels, ...]) + return sample, features + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + if add_channel_padding: + out_channels = out_channels * self.patch_size**3 + else: + out_channels = out_channels * self.patch_size_t * self.patch_size**2 + self.out_channels = out_channels + + self.conv_in = make_conv_nd( + dims, + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock3D( + dims=dims, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block and 2 ** (len(block_out_channels) - i - 1) > patch_size, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.up_blocks.append(up_block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd(dims, block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + upsample_in_time = sample.shape[2] < target_shape[2] + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = checkpoint_fn(self.mid_block)(sample) + sample = sample.to(upscale_dtype) + + for up_block in self.up_blocks: + sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # un-patchify + patch_size_t = self.patch_size_t if upsample_in_time else 1 + sample = unpatchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + return sample + + +class DownEncoderBlock3D(nn.Module): + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_padding: int = 1, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for 
i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_downsample: + self.downsample = Downsample3D( + dims, + out_channels, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsample = Identity() + + def forward(self, hidden_states: torch.FloatTensor, downsample_in_time) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + hidden_states = self.downsample(hidden_states, downsample_in_time=downsample_in_time) + + return hidden_states + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + return hidden_states + + +class UpDecoderBlock3D(nn.Module): + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_upsample: bool = True, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_upsample: + self.upsample = Upsample3D(dims=dims, channels=out_channels, out_channels=out_channels) + else: + self.upsample = Identity() + + self.resolution_idx = resolution_idx + + def forward(self, hidden_states: torch.FloatTensor, upsample_in_time=True) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. 
+ + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + if norm_layer == "group_norm": + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd(dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if norm_layer == "group_norm": + self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd(dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.conv_shortcut = ( + make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels) + if in_channels != out_channels + else nn.Identity() + ) + + def forward( + self, + input_tensor: torch.FloatTensor, + ) -> torch.FloatTensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states) + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class Downsample3D(nn.Module): + + def __init__( + self, + dims, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + padding: int = 1, + ): + super().__init__() + stride: int = 2 + self.padding = padding + self.in_channels = in_channels + self.dims = dims + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + def forward(self, x, downsample_in_time=True): + conv = self.conv + if self.padding == 0: + if self.dims == 2: + padding = (0, 1, 0, 1) + else: + padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0) + + x = functional.pad(x, padding, mode="constant", value=0) + + if self.dims == (2, 1) and not downsample_in_time: + return conv(x, skip_time_conv=True) + + return conv(x) + + +class Upsample3D(nn.Module): + """ + An upsampling layer for 3D tensors of shape (B, C, D, H, W). + + :param channels: channels in the inputs and outputs. 
+  """
+
+  def __init__(self, dims, channels, out_channels=None):
+    super().__init__()
+    self.dims = dims
+    self.channels = channels
+    self.out_channels = out_channels or channels
+    # Use the resolved channel count so a `None` out_channels falls back to `channels`.
+    self.conv = make_conv_nd(dims, channels, self.out_channels, kernel_size=3, padding=1, bias=True)
+
+  def forward(self, x, upsample_in_time):
+    if self.dims == 2:
+      x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest")
+    else:
+      time_scale_factor = 2 if upsample_in_time else 1
+      b, c, d, h, w = x.shape
+      x = rearrange(x, "b c d h w -> (b d) c h w")
+      # height and width interpolate
+      x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest")
+      _, _, h, w = x.shape
+
+      if not upsample_in_time and self.dims == (2, 1):
+        x = rearrange(x, "(b d) c h w -> b c d h w", b=b, h=h, w=w)
+        return self.conv(x, skip_time_conv=True)
+
+      # Second upsampling: a 1D nearest-neighbor interpolation across the 'd' (time) dimension
+      x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b)
+
+      # (b h w) c 1 d
+      new_d = x.shape[-1] * time_scale_factor
+      x = functional.interpolate(x, (1, new_d), mode="nearest")
+      # (b h w) c 1 new_d
+      x = rearrange(x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d)
+      # b c d h w
+
+    return self.conv(x)
+
+
+def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
+  if patch_size_hw == 1 and patch_size_t == 1:
+    return x
+  if x.dim() == 4:
+    x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw)
+  elif x.dim() == 5:
+    x = rearrange(
+        x,
+        "b c (f p) (h q) (w r) -> b (c p r q) f h w",
+        p=patch_size_t,
+        q=patch_size_hw,
+        r=patch_size_hw,
+    )
+  else:
+    raise ValueError(f"Invalid input shape: {x.shape}")
+
+  if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding):
+    channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
+    padding_zeros = torch.zeros(
+        x.shape[0],
+        channels_to_pad,
+        x.shape[2],
+        x.shape[3],
+        x.shape[4],
+        device=x.device,
+        dtype=x.dtype,
+    )
+    x = torch.cat([padding_zeros, x], dim=1)
+
+  return x
+
+
+def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
+  if patch_size_hw == 1 and patch_size_t == 1:
+    return x
+
+  if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding):
+    channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
+    x = x[:, :channels_to_keep, :, :, :]
+
+  if x.dim() == 4:
+    x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw)
+  elif x.dim() == 5:
+    x = rearrange(
+        x,
+        "b (c p r q) f h w -> b c (f p) (h q) (w r)",
+        p=patch_size_t,
+        q=patch_size_hw,
+        r=patch_size_hw,
+    )
+
+  return x
+
+
+def create_video_autoencoder_config(
+    latent_channels: int = 4,
+):
+  config = {
+      "_class_name": "VideoAutoencoder",
+      "dims": (
+          2,
+          1,
+      ),  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
+      "in_channels": 3,  # Number of input color channels (e.g., RGB)
+      "out_channels": 3,  # Number of output color channels
+      "latent_channels": latent_channels,  # Number of channels in the latent space representation
+      "block_out_channels": [
+          128,
+          256,
+          512,
+          512,
+      ],  # Number of output channels of each encoder / decoder inner block
+      "patch_size": 1,
+  }
+
+  return config
+
+
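For reference, the `patchify`/`unpatchify` pair above is pure shape bookkeeping: spatial (and optionally temporal) patches are folded into the channel axis and back. A minimal standalone sketch of the same `rearrange` patterns, with toy sizes rather than the model's real dimensions:

```python
import torch
from einops import rearrange

x = torch.randn(1, 3, 8, 64, 64)  # (B, C, F, H, W)
p_t, p_hw = 2, 4                  # temporal / spatial patch sizes (example values)

# Fold 2x4x4 patches into channels: C grows by p_t * p_hw^2, F/H/W shrink accordingly.
y = rearrange(x, "b c (f p) (h q) (w r) -> b (c p r q) f h w", p=p_t, q=p_hw, r=p_hw)
assert y.shape == (1, 3 * p_t * p_hw * p_hw, 4, 16, 16)

# The inverse rearrange is exact, since patchify is just a permutation of elements.
z = rearrange(y, "b (c p r q) f h w -> b c (f p) (h q) (w r)", p=p_t, q=p_hw, r=p_hw)
assert torch.equal(x, z)
```

+def 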
create_video_autoencoder_pathify4x4x4_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "latent_log_var": "uniform", + } + + return config + + +def create_video_autoencoder_pathify4x4_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "norm_layer": "pixel_norm", + } + + return config + + +def test_vae_patchify_unpatchify(): + import torch + + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) + + +def demo_video_autoencoder_forward_backward(): + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_pathify4x4x4_config() + + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = VideoAutoencoder.from_config(config) + + print(video_autoencoder) + + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 8, 64, 64) + + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") + reconstructed_videos = video_autoencoder.decode(latent, target_shape=input_videos.shape).sample + + print(f"reconstructed shape={reconstructed_videos.shape}") + + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + + # Perform backward pass + loss.backward() + + print(f"Demo completed with loss: {loss.item()}") + + +# Ensure to call the demo function to execute the forward and backward pass +if __name__ == "__main__": + demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py new file mode 100644 index 000000000..ee7221652 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -0,0 +1,86 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
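`gradient_checkpoint.py` (below) maps a policy name onto flax's `nn.remat`. A minimal usage sketch; `MlpBlock` is a made-up module for illustration and not part of this change:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

from maxdiffusion.models.ltx_video.gradient_checkpoint import GradientCheckpointType

class MlpBlock(nn.Module):  # toy module, purely illustrative
  @nn.compact
  def __call__(self, x):
    return nn.Dense(8)(nn.relu(nn.Dense(32)(x)))

# "matmul_without_batch" rematerializes matmuls that carry no batch dimension.
policy = GradientCheckpointType.from_str("matmul_without_batch")
RematMlp = policy.apply(MlpBlock)  # wraps the class in nn.remat ("none" returns it unchanged)

x = jnp.ones((4, 8))
params = RematMlp().init(jax.random.PRNGKey(0), x)
y = RematMlp().apply(params, x)
```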
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from enum import Enum, auto +from typing import Optional + +import jax +from flax import linen as nn + +SKIP_GRADIENT_CHECKPOINT_KEY = "skip" + + +class GradientCheckpointType(Enum): + """ + Defines the type of the gradient checkpoint we will have + + NONE - means no gradient checkpoint + FULL - means full gradient checkpoint, wherever possible (minimum memory usage) + MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, + except for ones that involve batch dimension - that means that all attention and projection + layers will have gradient checkpoint, but not the backward with respect to the parameters + """ + + NONE = auto() + FULL = auto() + MATMUL_WITHOUT_BATCH = auto() + + @classmethod + def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": + """ + Constructs the gradient checkpoint type from a string + + Args: + s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. + + Returns: + GradientCheckpointType: The policy that corresponds to the string + """ + if s is None: + s = "none" + return GradientCheckpointType[s.upper()] + + def to_jax_policy(self): + """ + Converts the gradient checkpoint type to a jax policy + """ + match self: + case GradientCheckpointType.NONE: + return SKIP_GRADIENT_CHECKPOINT_KEY + case GradientCheckpointType.FULL: + return None + case GradientCheckpointType.MATMUL_WITHOUT_BATCH: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + def apply(self, module: nn.Module) -> nn.Module: + """ + Applies a gradient checkpoint policy to a module + if no policy is needed, it will return the module as is + + Args: + module (nn.Module): the module to apply the policy to + + Returns: + nn.Module: the module with the policy applied + """ + policy = self.to_jax_policy() + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module + return nn.remat( # pylint: disable=invalid-name + module, + prevent_cse=False, + policy=policy, + ) diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py new file mode 100644 index 000000000..247e9da1f --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -0,0 +1,123 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
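`linear.py`, whose body follows, ports MaxText's `DenseGeneral`: a linear layer that contracts arbitrary input axes and can emit several output axes from a single `dot_general`. A usage sketch; the shapes and logical axis names here are illustrative:

```python
import jax
import jax.numpy as jnp

from maxdiffusion.models.ltx_video.linear import DenseGeneral

# Project the embedding axis of (batch, seq, embed) directly to (heads, head_dim).
layer = DenseGeneral(features=(8, 64), axis=-1, kernel_axes=("embed", "heads", "kv"))
x = jnp.ones((2, 128, 512))
params = layer.init(jax.random.PRNGKey(0), x)  # kernel has shape (512, 8, 64)
y = layer.apply(params, x)
print(y.shape)  # (2, 128, 8, 64)
```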
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import Union, Iterable, Tuple, Optional, Callable + +import numpy as np +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + + +Shape = Tuple[int, ...] +Initializer = Callable[[jax.random.PRNGKey, Shape, jax.numpy.dtype], jax.Array] +InitializerAxis = Union[int, Shape] + + +def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple(ax if ax >= 0 else ndim + ax for ax in axes) + + +def _canonicalize_tuple(x): + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +NdInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +KernelInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] + + +class DenseGeneral(nn.Module): + """A linear transformation with flexible axes. + + Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + use_bias: whether to add bias in linear transformation. + bias_norm: whether to add normalization before adding bias. + quant: quantization config, defaults to None implying no quantization. + """ + + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + weight_dtype: jnp.dtype = jnp.float32 + dtype: np.dtype = jnp.float32 + kernel_init: KernelInitializer = lecun_normal() + kernel_axes: Tuple[Optional[str], ...] = () + use_bias: bool = False + matmul_precision: str = "default" + + bias_init: Initializer = jax.nn.initializers.constant(0.0) + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. 
+ """ + + def compute_dot_general(inputs, kernel, axis, contract_ind): + """Computes a dot_general operation that may be quantized.""" + dot_general = jax.lax.dot_general + matmul_precision = jax.lax.Precision(self.matmul_precision) + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) + + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + kernel = self.param( + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + ) + kernel = jnp.asarray(kernel, self.dtype) + + contract_ind = tuple(range(0, len(axis))) + output = compute_dot_general(inputs, kernel, axis, contract_ind) + + if self.use_bias: + bias_axes, bias_shape = ( + self.kernel_axes[-len(features) :], + kernel_shape[-len(features) :], + ) + bias = self.param( + "bias", + nn.with_logical_partitioning(self.bias_init, bias_axes), + bias_shape, + self.weight_dtype, + ) + bias = jnp.asarray(bias, self.dtype) + + output += bias + return output diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py new file mode 100644 index 000000000..8f6e43dc0 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -0,0 +1,125 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
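`repeatable_layer.py` (next) stacks N structurally identical blocks behind one `nn.scan`, so the block graph compiles once rather than N times. A toy sketch of how it is driven; `ToyBlock` is illustrative and only has to accept `(index, x)` the way `RepeatableCarryBlock` calls it:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

from maxdiffusion.models.ltx_video.repeatable_layer import RepeatableLayer

class ToyBlock(nn.Module):  # stand-in for a transformer block
  dim: int

  @nn.compact
  def __call__(self, index, x):
    # `index` is the scan-carried layer counter (upstream it feeds skip-layer logic).
    return x + nn.Dense(self.dim)(x)

layer = RepeatableLayer(module=ToyBlock, num_layers=4, module_init_kwargs={"dim": 8})
x = jnp.ones((2, 8))
params = layer.init(jax.random.PRNGKey(0), x)
# Per-layer kernels are stacked along param_scan_axis=0: one (4, 8, 8) array
# instead of four separate (8, 8) kernels.
print(jax.tree_util.tree_map(jnp.shape, params))
y = layer.apply(params, x)
```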
+#
+# This implementation is based on the Torch version available at:
+# https://github.com/Lightricks/LTX-Video/tree/main
+from dataclasses import field
+from typing import Any, Callable, Dict, List, Tuple, Optional
+
+import jax
+from flax import linen as nn
+import jax.numpy as jnp
+from flax.linen import partitioning
+
+
+class RepeatableCarryBlock(nn.Module):
+  module: Callable[[Any], nn.Module]
+  module_init_args: List[Any]
+  module_init_kwargs: Dict[str, Any]
+
+  @nn.compact
+  def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tuple[jax.Array, jax.Array], None]:
+    data_input, index_input = carry
+
+    mod = self.module(*self.module_init_args, **self.module_init_kwargs)
+
+    output_data = mod(index_input, data_input, *block_args)  # Pass index_input to facilitate skip layers
+
+    next_index = index_input + 1
+    new_carry = (output_data, next_index)
+
+    return new_carry, None
+
+
+class RepeatableLayer(nn.Module):
+  """
+  RepeatableLayer plays a role similar to torch.nn.ModuleList,
+  under the condition that each block has the same graph and only the parameters differ.
+
+  Compilation in RepeatableLayer happens only once, in contrast to compiling the graph once per repeated block.
+  """
+
+  module: Callable[[Any], nn.Module]
+  """
+  A callable for constructing a single block
+  """
+
+  num_layers: int
+  """
+  The number of blocks to build
+  """
+
+  module_init_args: List[Any] = field(default_factory=list)
+  """
+  args passed to RepeatableLayer.module callable, to support block construction
+  """
+
+  module_init_kwargs: Dict[str, Any] = field(default_factory=dict)
+  """
+  kwargs passed to RepeatableLayer.module callable, to support block construction
+  """
+
+  pspec_name: Optional[str] = None
+  """
+  Partition spec metadata
+  """
+
+  param_scan_axis: int = 0
+  """
+  The axis that the per-layer parameters are stacked (aggregated) on,
+  e.g. if a kernel is shaped (8, 16),
+  N layers will be (N, 8, 16) if param_scan_axis=0
+  and (8, N, 16) if param_scan_axis=1
+  """
+
+  @nn.compact
+  def __call__(self, *args):
+    if not args:
+      raise ValueError("RepeatableLayer expects at least one argument for initial data input.")
+
+    initial_data_input = args[0]
+    static_block_args = args[1:]
+
+    initial_index = jnp.array(0, dtype=jnp.int32)  # index of current transformer block
+
+    scan_kwargs = {}
+    if self.pspec_name is not None:
+      scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name}
+
+    initializing = self.is_mutable_collection("params")
+    params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis)
+
+    in_axes_for_scan = (nn.broadcast,) * (len(args) - 1)
+
+    scan_fn = nn.scan(
+        RepeatableCarryBlock,
+        variable_axes={
+            "params": params_spec,
+            "cache": 0,
+            "intermediates": 0,
+            "aqt": 0,
+            "_overwrite_with_gradient": 0,
+        },
+        split_rngs={"params": True},
+        in_axes=in_axes_for_scan,
+        length=self.num_layers,
+        **scan_kwargs,
+    )
+
+    wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs)
+
+    # Call wrapped_function with the initial carry tuple and the static_block_args
+    (final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args)
+
+    return final_data
diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py
new file mode 100644
index 000000000..7e4185f36
--- /dev/null
+++ b/src/maxdiffusion/models/ltx_video/transformers/__init__.py
@@ -0,0 +1,15 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed 
under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py new file mode 100644 index 000000000..4ae1d9a00 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -0,0 +1,189 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, KernelInitializer + + +ACTIVATION_FUNCTIONS = { + "swish": jax.nn.silu, + "silu": jax.nn.silu, + # Mish is not in JAX by default + "mish": lambda x: x * jax.nn.tanh(jax.nn.softplus(x)), + "gelu": jax.nn.gelu, + "relu": jax.nn.relu, +} + + +@jax.jit +def approximate_gelu(x: jax.Array) -> jax.Array: + """ + Computes Gaussian Error Linear Unit (GELU) activation function + + Args: + x (jax.Array): The input tensor + + jax.Array: The output tensor + """ + # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs + # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior + if x.dtype in (jax.numpy.float64,): + x = x.clip(-10, None) + return jax.nn.gelu(x, approximate=True) + + +def get_activation(act_fn: str): + """Returns the activation function from string.""" + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + approximate: str = "none" + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] 
= () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def gelu(self, gate: jax.Array) -> jax.Array: + approximate_to_tanh = self.approximate == "tanh" + if approximate_to_tanh: + return approximate_gelu(gate) + else: + return jax.nn.gelu(gate, approximate=False) + + @nn.compact + def __call__(self, hidden_states): + if self.approximate not in ("none", "tanh"): + raise ValueError(f"approximate must be 'none' or 'tanh', got {self.approximate}") + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + hidden_states = proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states, *args, **kwargs): + # if len(args) > 0 or kwargs.get("scale", None) is not None: + # deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + # deprecate("scale", "1.0.0", deprecation_message) + + proj = DenseGeneral( + features=self.dim_out * 2, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + + hidden_states = proj(hidden_states) + hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) + return hidden_states * jax.nn.gelu(gate, approximate=False) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] 
= () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, x): + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + x = proj(x) + return x * jax.nn.sigmoid(1.702 * x) diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py new file mode 100644 index 000000000..1078f0848 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -0,0 +1,208 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import Dict, Optional, Tuple + +import jax +import jax.nn +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.transformers.activations import get_activation +from maxdiffusion.models.ltx_video.linear import DenseGeneral + + +def get_timestep_embedding_multidim( + timesteps: jnp.ndarray, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> jnp.ndarray: + """ + Computes sinusoidal timestep embeddings while preserving the original dimensions. + No reshaping to 1D is performed at any stage. + + Args: + timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. + embedding_dim (int): The dimension of the output. + flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) + or `sin, cos` (if False). + downscale_freq_shift (float): Controls the delta between frequencies between dimensions. + scale (float): Scaling factor applied to the embeddings. + max_period (int): Controls the maximum frequency of the embeddings. + + Returns: + jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. 
+ """ + half_dim = embedding_dim // 2 + exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) + emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape + emb = nn.with_logical_constraint(emb, ("activation_batch", "activation_norm_length", "activation_embed")) + # Broadcasting to match shape (*timesteps.shape, half_dim) + emb = timesteps[..., None] * emb + emb = scale * emb + # Shape (*timesteps.shape, embedding_dim) + emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) + if flip_sin_to_cos: + emb = jnp.concatenate([emb[..., half_dim:], emb[..., :half_dim]], axis=-1) + + return emb + + +class TimestepEmbedding(nn.Module): + in_channels: int + time_embed_dim: int + act_fn: str = "silu" + out_dim: Optional[int] = None + sample_proj_bias: bool = True + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers efficiently""" + self.linear_1 = DenseGeneral( + self.time_embed_dim, + use_bias=self.sample_proj_bias, + kernel_axes=(None, "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + ) + + self.act = get_activation(self.act_fn) + time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim + self.linear_2 = DenseGeneral( + time_embed_dim_out, + use_bias=self.sample_proj_bias, + kernel_axes=("embed", "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + ) + + def __call__(self, sample, condition=None): + sample = nn.with_logical_constraint(sample, ("activation_batch", "activation_norm_length", "activation_embed")) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + num_channels: int + flip_sin_to_cos: bool + downscale_freq_shift: float + scale: int = 1 + + def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: + t_emb = get_timestep_embedding_multidim( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class AlphaCombinedTimestepSizeEmbeddings(nn.Module): + + embedding_dim: int + size_emb_dim: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize sub-modules.""" + self.outdim = self.size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.embedding_dim, + name="timestep_embedder", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype)) + return timesteps_emb + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
+ """ + + embedding_dim: int + embedding_coefficient: int = 6 + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + self.emb = AlphaCombinedTimestepSizeEmbeddings( + self.embedding_dim, + size_emb_dim=self.embedding_dim // 3, + name="emb", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + self.silu = jax.nn.silu + self.linear = DenseGeneral( + self.embedding_coefficient * self.embedding_dim, + use_bias=True, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear", + ) + + def __call__( + self, + timestep: jnp.ndarray, + added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[jnp.dtype] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Compute AdaLayerNorm-Single modulation. + + Returns: + Tuple: + - Processed embedding after SiLU + linear transformation. + - Original embedded timestep. + """ + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py new file mode 100644 index 000000000..6fad32d8e --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -0,0 +1,912 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from functools import partial +import math +from typing import Any, Dict, Optional, Tuple +from enum import Enum, auto +import jax +import jax.nn as jnn +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from jax.experimental.shard_map import shard_map +from jax.experimental.pallas.ops.tpu.flash_attention import ( + flash_attention as jax_flash_attention, + SegmentIds, + BlockSizes, +) + +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, Initializer +from maxdiffusion.models.ltx_video.transformers.activations import ( + GELU, + GEGLU, + ApproximateGELU, +) + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() + + +class Identity(nn.Module): + + def __call__(self, x): + return x + + +class BasicTransformerBlock(nn.Module): + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", "rms_norm"] + assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 2. 
Cross-Attn
+    if self.cross_attention_dim is not None or self.double_self_attention:
+      self.attn2 = Attention(
+          query_dim=self.dim,
+          cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None,
+          heads=self.num_attention_heads,
+          dim_head=self.attention_head_dim,
+          dropout=self.dropout,
+          bias=self.attention_bias,
+          upcast_attention=self.upcast_attention,
+          out_bias=self.attention_out_bias,
+          use_tpu_flash_attention=self.use_tpu_flash_attention,
+          qk_norm=self.qk_norm,
+          use_rope=self.use_rope,
+          attention_op=self.attention_op,
+          name="attn2",
+          dtype=self.dtype,
+          weight_dtype=self.weight_dtype,
+      )
+      if self.adaptive_norm == "none":
+        self.attn2_norm = make_norm_layer()
+    else:
+      self.attn2 = None
+      self.attn2_norm = None
+
+    self.norm2 = make_norm_layer(name="norm2")
+    # 3. Feed-forward
+    self.ff = FeedForward(
+        self.dim,
+        dropout=self.dropout,
+        activation_fn=self.activation_fn,
+        final_dropout=self.final_dropout,
+        inner_dim=self.ff_inner_dim,
+        bias=self.ff_bias,
+        mult=self.ffn_dim_mult,
+        name="ff",
+        dtype=self.dtype,
+        weight_dtype=self.weight_dtype,
+        matmul_precision=self.matmul_precision,
+    )
+
+    # 4. Scale-Shift
+    if self.adaptive_norm != "none":
+      num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6
+
+      def ada_initializer(key):
+        return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5
+
+      self.scale_shift_table = self.param(
+          "scale_shift_table",  # Trainable parameter name
+          nn.with_logical_partitioning(ada_initializer, ("ada", "embed")),
+      )
+
+  def __call__(
+      self,
+      index: int,
+      hidden_states: jnp.ndarray,
+      freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
+      segment_ids: Optional[jnp.ndarray] = None,
+      encoder_hidden_states: Optional[jnp.ndarray] = None,
+      encoder_attention_segment_ids: Optional[jnp.ndarray] = None,
+      timestep: Optional[jnp.ndarray] = None,
+      cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+      class_labels: Optional[jnp.ndarray] = None,
+      skip_layer_mask: Optional[jnp.ndarray] = None,
+      skip_layer_strategy: Optional[SkipLayerStrategy] = None,
+  ) -> jnp.ndarray:
+    # Default to AttentionValues when the caller does not specify a strategy,
+    # instead of unconditionally overriding the argument.
+    if skip_layer_strategy is None:
+      skip_layer_strategy = SkipLayerStrategy.AttentionValues
+    if cross_attention_kwargs is not None:
+      if cross_attention_kwargs.get("scale", None) is not None:
+        print("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+    hidden_states = nn.with_logical_constraint(
+        hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
+    )
+    hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states")
+
+    batch_size = hidden_states.shape[0]
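For intuition, here is the `single_scale_shift` modulation that `__call__` performs just below, as a self-contained numeric sketch (toy shapes; in the model the table is learned and `timestep` comes from `AdaLayerNormSingle`):

```python
import jax.numpy as jnp

batch, tokens, dim = 2, 5, 8
scale_shift_table = jnp.zeros((6, dim))    # learned (num_ada_params, dim) table
timestep = jnp.zeros((batch, 1, 6 * dim))  # per-sample conditioning embedding

ada = scale_shift_table[None, None] + timestep.reshape(batch, 1, 6, dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    jnp.squeeze(a, axis=2) for a in jnp.split(ada, 6, axis=2)
)
x = jnp.ones((batch, tokens, dim))           # normed hidden states
modulated = x * (1 + scale_msa) + shift_msa  # the "single_scale_shift" branch
print(modulated.shape)                       # (2, 5, 8)
```

+
+    # 0. 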
Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = (jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2)) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. Self-Attention + + attn_output = self.attn1( + norm_hidden_states, + block_index=index, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) + + attn_output = nn.with_logical_constraint(attn_output, ("activation_batch", "activation_norm_length", "activation_embed")) + + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint(attn_input, ("activation_batch", "activation_norm_length", "activation_embed")) + attn_output = self.attn2( + attn_input, + block_index=-1, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint(ff_output, ("activation_batch", "activation_norm_length", "activation_embed")) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states + + +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) + + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="to_out.0", + matmul_precision=self.matmul_precision, + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any(device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print("Warning: Running with explicit attention since TPU is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + block_index: int = -1, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + assert cross_attention_kwargs.get("scale", None) is None, "Passing `scale` via cross_attention_kwargs is not supported" + + input_axis_names = ("activation_batch", "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape(skip_layer_mask[block_index], (batch_size, 1, 1)) + + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key = self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query =
apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + value_for_stg = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") + + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + + assert self.use_tpu_flash_attention, "The JAX port only supports `use_tpu_flash_attention`" + + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) + + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) + + hidden_states_a = self.attention(query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype) + + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + elif skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionValues: + hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) +
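+ # Shape walkthrough (illustrative, not from the original source): a mask of shape + # (2, 100) with target_length=128 is zero-padded above to (2, 128); the out_dim == 3 + # branch below then tiles it per head along axis 0 to (2 * heads_count, 128).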
+ if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states + + +class AttentionOp(nn.Module): + + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + qkvo_sharding_spec = jax.sharding.PartitionSpec( + "data", + "fsdp", + None, + "tensor", + ) + # Based on: ("activation_kv_batch", "activation_length") + qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None) + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. 
+ + TPU kernel ops run in grids; the bigger the grid, the more data is loaded into SRAM, + and we want to utilize the SRAM as well as we can. + + Grids that are too big will cause cache misses and slow down the computation while the faster SRAM fetches the other block data + from the slower HBM. + + A certain balance has to be struck to get the best performance; + that balance should be computed from the information supplied by q and k (the query and key/value sequence lengths) + along with the SRAM cache size. + + ** SRAM cache size for TPU + v5p - 1MB SRAM per core + + Args: + q (jax.Array): Query tensor to be used + k (jax.Array): Key tensor to be used + dtype (jnp.dtype): Computation dtype, which bounds the maximum block size + + Returns: + BlockSizes: Grid block sizes + """ + max_block_size = 1024 if dtype == jnp.bfloat16 else 512 + return BlockSizes( + block_q=min(max_block_size, q.shape[-2]), + block_k_major=min(max_block_size, k.shape[-2]), + block_k=min(max_block_size, k.shape[-2]), + block_b=min(1, q.shape[0]), + block_q_major_dkv=min(max_block_size, q.shape[-2]), + block_k_major_dkv=min(max_block_size, k.shape[-2]), + block_q_dkv=min(max_block_size, q.shape[-2]), + block_k_dkv=min(max_block_size, k.shape[-2]), + block_q_dq=min(max_block_size, q.shape[-2]), + block_k_dq=min(512, k.shape[-2]), + block_k_major_dq=min(max_block_size, k.shape[-2]), + ) + + +class ExplicitAttention(nn.Module): + + def __call__( + self, + q: jax.Array, + k: jax.Array, + v: jax.Array, + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + ): + assert sharding_mesh is None, "Explicit attention does not support sharding mesh." + attn_mask = None + if kv_segment_ids is not None: + q_segment_ids_expanded = q_segment_ids[:, None, :, None] + kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] + attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded + + scale_factor = 1 / jnp.sqrt(q.shape[-1]) + attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == jnp.bool_: + attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = q @ k.swapaxes(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = jnn.softmax(attn_weight, axis=-1) + + return attn_weight @ v + + +class RMSNorm(nn.Module): + """ + RMSNorm is a normalization layer that normalizes the input using the root mean square. + """ + + epsilon: float + dtype: jnp.dtype = jnp.float32 + elementwise_affine: bool = True + weight_dtype: jnp.dtype = jnp.float32 + kernel_axes: Tuple[Optional[str], ...] = () + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, hidden_states: jax.Array) -> jax.Array: + """ + Forward pass of the RMSNorm layer. + + First we compute the variance (mean of the square of the input) + and then normalize the input using the root mean square. + + NOTE: if weight is in mixed precision, the operand should be in the same precision.
+ Args: + hidden_states (jax.Array): Input data + + Returns: + jax.Array: Normed data + """ + + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) + + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError(f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if self.final_dropout: + # 
FF blocks as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + return hidden_states + + +def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: + """ + Integrates positional information into input tensors using RoPE. + + Args: + input_tensor (jax.Array): Input tensor (a query or key from the attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") + + cos_freqs, sin_freqs = freqs_cis + + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) + + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py new file mode 100644 index 000000000..d8240989c --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -0,0 +1,56 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from flax import linen as nn +import jax.numpy as jnp + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.activations import approximate_gelu + + +class CaptionProjection(nn.Module): + """ + Projects caption embeddings to the transformer hidden size. + """ + + in_features: int + hidden_size: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, caption): + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + )(caption) + hidden_states = approximate_gelu(hidden_states) + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + )(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py new file mode 100644 index 000000000..d53b4d7ca --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py @@ -0,0 +1,98 @@ +# Copyright 2025 Lightricks Ltd.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from diffusers.configuration_utils import ConfigMixin +from einops import rearrange +from torch import Tensor + + +class Patchifier(ConfigMixin, ABC): + + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) + + @abstractmethod + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + raise NotImplementedError("Patchify method not implemented") + + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tensor: + pass + + @property + def patch_size(self): + return self._patch_size + + def get_latent_coords(self, latent_num_frames, latent_height, latent_width, batch_size, device): + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. + The tensor is repeated for each batch element. + """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange(latent_coords, "b c f h w -> b c (f h w)", b=batch_size) + return latent_coords + + +class SymmetricPatchifier(Patchifier): + + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents, latent_coords + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tensor: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py new file mode 100644 index 000000000..8b12b1d81 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -0,0 +1,326 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import List, Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.adaln import AdaLayerNormSingle +from maxdiffusion.models.ltx_video.transformers.attention import BasicTransformerBlock +from maxdiffusion.models.ltx_video.transformers.caption_projection import CaptionProjection +from maxdiffusion.models.ltx_video.gradient_checkpoint import GradientCheckpointType +from maxdiffusion.models.ltx_video.repeatable_layer import RepeatableLayer + + +class Transformer3DModel(nn.Module): + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: Optional[int] = None + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config."
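+ # Illustrative arithmetic: with the defaults above (num_attention_heads=16, + # attention_head_dim=88) the transformer width computed below is 16 * 88 = 1408.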
+ self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( # noqa: C408 + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def init_weights(self, in_channels, key, caption_channels, eval_only=True): + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", 
"encoder_attention_mask"] else jnp.bool + ) + + if eval_only: + return jax.eval_shape( + self.init, + key, + **example_inputs, + )["params"] + else: + return self.init(key, **example_inputs)["params"] + + def create_skip_layer_mask( + self, + batch_size: int, + num_conds: int, + ptb_index: int, + skip_block_list: Optional[List[int]] = None, + ) -> Optional[jnp.ndarray]: + if skip_block_list is None or len(skip_block_list) == 0: + return None + mask = jnp.ones((self.num_layers, batch_size * num_conds), dtype=self.dtype) + + for block_idx in skip_block_list: + mask = mask.at[block_idx, ptb_index::num_conds].set(0) + + return mask + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + skip_layer_mask=None, + skip_layer_strategy=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + + if self.num_layers > 0: + + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + skip_layer_mask, + skip_layer_strategy, + ) + # Output processing + + scale_shift_values = self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] + scale_shift_values = nn.with_logical_constraint( + scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states + + +def log_base(x: jax.Array, base: jax.Array) -> jax.Array: + """ + Computes log of x with defined base. + + Args: + x (jax.Array): log value + base (jax.Array): base of the log + + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) + + +class FreqsCisPrecomputer(nn.Module): + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + dtype = jnp.float32 # We need full precision in the freqs_cis computation. 
+ dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py new file mode 100644 index 000000000..285b6e81c --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py new file mode 100644 index 000000000..a598114ad --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py @@ -0,0 +1,1137 @@ +import inspect +from importlib import import_module +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention import _chunked_feed_forward +from diffusers.models.attention_processor import ( + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + SpatialNorm, +) +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.models.normalization import RMSNorm +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange +from torch import nn + +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +try: + from torch_xla.experimental.custom_kernel import flash_attention +except ImportError: + # workaround for automatic tests. 
Currently this function is manually patched + # into the torch_xla lib when the container is set up + pass + +# code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py + +logger = logging.get_logger(__name__) + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`): + The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none". + standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): + The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply.
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, # pylint: disable=unused-argument + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift', 'single_scale' or 'none' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_eps: float = 1e-5, + qk_norm: Optional[str] = None, + final_dropout: bool = False, + attention_type: str = "default", # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + use_tpu_flash_attention: bool = False, + use_rope: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_tpu_flash_attention = use_tpu_flash_attention + self.adaptive_norm = adaptive_norm + + assert standardization_norm in ["layer_norm", "rms_norm"] + assert adaptive_norm in ["single_scale_shift", "single_scale", "none"] + + make_norm_layer = nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = make_norm_layer(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=use_rope, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=(cross_attention_dim if not double_self_attention else None), + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=use_rope, + ) # is self-attn if encoder_hidden_states is none + + if adaptive_norm == "none": + self.attn2_norm = make_norm_layer(dim, norm_eps, norm_elementwise_affine) + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine) + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 5. Scale-shift for PixArt-Alpha. + if adaptive_norm != "none": + num_ada_params = 4 if adaptive_norm == "single_scale" else 6 + self.scale_shift_table = nn.Parameter(torch.randn(num_ada_params, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. 
+ """ + self.use_tpu_flash_attention = True + self.attn1.set_use_tpu_flash_attention() + self.attn2.set_use_tpu_flash_attention() + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + original_hidden_states = hidden_states + + norm_hidden_states = self.norm1(hidden_states) + + # Apply ada_norm_single + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None] + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1) + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + norm_hidden_states = norm_hidden_states.squeeze(1) # TODO: Check if this is needed + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=(encoder_hidden_states if self.only_cross_attention else None), + attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **cross_attention_kwargs, + ) + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.adaptive_norm == "none": + attn_input = self.attn2_norm(hidden_states) + else: + attn_input = hidden_states + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-forward + norm_hidden_states = self.norm2(hidden_states) + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.TransformerBlock: + skip_layer_mask = skip_layer_mask.view(-1, 1, 1) + hidden_states = hidden_states * skip_layer_mask + original_hidden_states * (1.0 - skip_layer_mask) + + return hidden_states + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. 
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + qk_norm: Optional[str] = None, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: Optional[int] = None, + use_tpu_flash_attention: bool = False, + use_rope: bool = False, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.use_tpu_flash_attention = use_tpu_flash_attention + self.use_rope = use_rope + + # we make use of this private variable to know whether this class is loaded + # with a deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + if qk_norm is None: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + elif qk_norm == "rms_norm": + self.q_norm = RMSNorm(dim_head * heads, eps=1e-5) + self.k_norm = RMSNorm(dim_head * heads, eps=1e-5) + elif qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) + self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) + else: + raise ValueError(f"Unsupported qk_norm method: {qk_norm}") + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, + num_groups=cross_attention_norm_num_groups, + eps=1e-5, + affine=True, + ) + else: + raise ValueError(f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'") + + linear_cls = nn.Linear + + self.linear_cls = linear_cls + self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = AttnProcessor2_0() + self.set_processor(processor) + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel. + """ + self.use_tpu_flash_attention = True + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": # noqa: F821 + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. 
+ + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None for name, module in self.named_modules() if hasattr(module, "lora_layer") + } + + # 1. if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # LoRA is not applied to `add_k_proj` or `add_v_proj`, so drop them from the check + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not possible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError(f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}") + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [ + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + ]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict()) + lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict()) + else: + lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + + def forward( + self,
hidden_states: torch.FloatTensor, + freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + skip_layer_mask (`torch.Tensor`, *optional*): + The skip layer mask to use. If `None`, no mask is applied. + skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers to skip for spatiotemporal guidance. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by" + f" {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. 
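+ + Example (illustrative, assuming an `Attention` module `attn` built with `heads=2`): + >>> attn.head_to_batch_dim(torch.randn(1, 16, 8)).shape + torch.Size([2, 16, 4])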
+ """ + + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + + return tensor + + def get_attention_scores( + self, + query: torch.Tensor, + key: torch.Tensor, + attention_mask: torch.Tensor = None, + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, + attention_mask: torch.Tensor, + target_length: int, + batch_size: int, + out_dim: int = 3, + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. 
+ padding_shape = ( + attention_mask.shape[0], + attention_mask.shape[1], + target_length, + ) + padding = torch.zeros( + padding_shape, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @staticmethod + def apply_rotary_emb( + input_tensor: torch.Tensor, + freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + cos_freqs = freqs_cis[0] + sin_freqs = freqs_cis[1] + + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.FloatTensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. 
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + if skip_layer_mask is not None: + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1) + + if (attention_mask is not None) and (not attn.use_tpu_flash_attention): + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.q_norm(query) + + if encoder_hidden_states is not None: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + key = attn.k_norm(key) + else: # if no context provided do self-attention + encoder_hidden_states = hidden_states + key = attn.to_k(hidden_states) + key = attn.k_norm(key) + if attn.use_rope: + key = attn.apply_rotary_emb(key, freqs_cis) + query = attn.apply_rotary_emb(query, freqs_cis) + + value = attn.to_v(encoder_hidden_states) + value_for_stg = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + + if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention' + q_segment_indexes = None + if attention_mask is not None: # if a mask is required, both segment-id fields need to be set + # attention_mask = torch.squeeze(attention_mask).to(torch.float32) + attention_mask = attention_mask.to(torch.float32) + q_segment_indexes = torch.ones(batch_size, query.shape[2], device=query.device, dtype=torch.float32) + assert ( + attention_mask.shape[1] == key.shape[2] + ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]" + + assert query.shape[2] % 128 == 0, f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]" + assert key.shape[2] % 128 == 0, f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]" + + # run the TPU kernel implemented in jax with pallas + hidden_states_a = flash_attention( + q=query, + k=key, + v=value, + q_segment_ids=q_segment_indexes, + kv_segment_ids=attention_mask, + sm_scale=attn.scale, + ) + else: + hidden_states_a = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) + + hidden_states_a = hidden_states_a.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_a = hidden_states_a.to(query.dtype) + + if skip_layer_mask is not None
and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + elif skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionValues: + hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1) + + if attn.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
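+ # Note: scaling always uses the constructor-time `attn.scale` (see `get_attention_scores`); + # a runtime `scale` kwarg is ignored.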
+ deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + query = attn.q_norm(query) + key = attn.k_norm(key) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + else: + raise ValueError(f"Unsupported activation function: {activation_fn}") + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. 
have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py new file mode 100644 index 000000000..b4ca6b52f --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py @@ -0,0 +1,141 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py +import math + +import numpy as np +import torch +from einops import rearrange +from torch import nn + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
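+ + Example (illustrative): + >>> emb = get_timestep_embedding(torch.tensor([0.0, 10.0]), embedding_dim=8) + >>> emb.shape + torch.Size([2, 8])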
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w) + grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w) + grid = grid.reshape([3, 1, w, h, f]) + pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) + pos_embed = pos_embed.transpose(1, 0, 2, 3) + return rearrange(pos_embed, "h w f c -> (f h w) c") + + +def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 3 != 0: + raise ValueError("embed_dim must be divisible by 3") + + # use half of dimensions to encode grid_h + emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3) + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3) + + emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos_shape = pos.shape + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + out = out.reshape([*pos_shape, -1])[0] + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D) + return emb + + +class SinusoidalPositionalEmbedding(nn.Module): + """Apply positional information to a sequence of embeddings. + + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + + Args: + embed_dim: (int): Dimension of the positional embedding. 
+ max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] + return x diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py new file mode 100644 index 000000000..d53b4d7ca --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py @@ -0,0 +1,98 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from diffusers.configuration_utils import ConfigMixin +from einops import rearrange +from torch import Tensor + + +class Patchifier(ConfigMixin, ABC): + + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) + + @abstractmethod + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + raise NotImplementedError("Patchify method not implemented") + + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass + + @property + def patch_size(self): + return self._patch_size + + def get_latent_coords(self, latent_num_frames, latent_height, latent_width, batch_size, device): + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. + The tensor is repeated for each batch element. 
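+ + Example (illustrative), for one 4x4 latent frame with patch_size=2: + >>> coords = SymmetricPatchifier(patch_size=2).get_latent_coords(1, 4, 4, 1, "cpu") + >>> coords.shape + torch.Size([1, 3, 4])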
+ """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange(latent_coords, "b c f h w -> b c (f h w)", b=batch_size) + return latent_coords + + +class SymmetricPatchifier(Patchifier): + + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents, latent_coords + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py new file mode 100644 index 000000000..2ade88b86 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py @@ -0,0 +1,472 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union +import os +import json +import glob +from pathlib import Path + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import PixArtAlphaTextProjection +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle +from diffusers.utils import BaseOutput, is_torch_version +from diffusers.utils import logging +from torch import nn +from safetensors import safe_open + + +from ltx_video.models.transformers.attention import BasicTransformerBlock +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +from ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + TRANSFORMER_KEYS_RENAME_DICT, +) + + +logger = logging.get_logger(__name__) + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None, + positional_embedding_type: str = "rope", + positional_embedding_theta: Optional[float] = None, + positional_embedding_max_pos: Optional[List[int]] = None, + timestep_scale_multiplier: Optional[float] = None, + causal_temporal_positioning: bool = False, # For backward compatibility, will be deprecated + ): + super().__init__() + self.use_tpu_flash_attention = use_tpu_flash_attention # FIXME: push config down to the attention modules + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True) + 
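+ # inner_dim = num_attention_heads * attention_head_dim; patchify_proj lifts each patchified + # latent token from `in_channels` to the transformer width.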
self.positional_embedding_type = positional_embedding_type + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.use_rope = self.positional_embedding_type == "rope" + self.timestep_scale_multiplier = timestep_scale_multiplier + + if self.positional_embedding_type == "absolute": + raise ValueError("Absolute positional embedding is no longer supported") + elif self.positional_embedding_type == "rope": + if positional_embedding_theta is None: + raise ValueError("If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined") + if positional_embedding_max_pos is None: + raise ValueError("If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined") + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + adaptive_norm=adaptive_norm, + standardization_norm=standardization_norm, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=self.use_rope, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, self.out_channels) + + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) + if adaptive_norm == "single_scale": + self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. 
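+ Note that the Pallas flash-attention kernel used downstream requires query/key sequence + lengths that are multiples of 128 (see `AttnProcessor2_0` in the attention module).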
+ """ + logger.info("ENABLE TPU FLASH ATTENTION -> TRUE") + self.use_tpu_flash_attention = True + # push config down to the attention modules + for block in self.transformer_blocks: + block.set_use_tpu_flash_attention() + + def create_skip_layer_mask( + self, + batch_size: int, + num_conds: int, + ptb_index: int, + skip_block_list: Optional[List[int]] = None, + ): + if skip_block_list is None or len(skip_block_list) == 0: + return None + num_layers = len(self.transformer_blocks) + mask = torch.ones((num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype) + for block_idx in skip_block_list: + mask[block_idx, ptb_index::num_conds] = 0 + return mask + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + dim=-1, + ) + return fractional_positions + + def precompute_freqs_cis(self, indices_grid, spacing="exp"): + dtype = torch.float32 # We need full precision in the freqs_cis computation. + dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + device = fractional_positions.device + if spacing == "exp": + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + dim // 6, + device=device, + dtype=dtype, + ) + ) + indices = indices.to(dtype=dtype) + elif spacing == "exp_2": + indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim) + indices = indices.to(dtype=dtype) + elif spacing == "linear": + indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype) + elif spacing == "sqrt": + indices = torch.linspace(start**2, end**2, dim // 6, device=device, dtype=dtype).sqrt() + + indices = indices * math.pi / 2 + + if spacing == "exp_2": + freqs = (indices * fractional_positions.unsqueeze(-1)).transpose(-1, -2).flatten(2) + else: + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if dim % 6 != 0: + cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq.to(self.dtype), sin_freq.to(self.dtype) + + def load_state_dict( + self, + state_dict: Dict, + *args, + **kwargs, + ): + if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]): # noqa: C419 + state_dict = { + key.replace("model.diffusion_model.", ""): value + for key, value in state_dict.items() + if key.startswith("model.diffusion_model.") + } + super().load_state_dict(state_dict, *args, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_dir(): + config_path = pretrained_model_path / "transformer" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) + + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for transformer is not suppported. 
" + "We only support diffusers configs found in Lightricks/LTX-Video." + ) + + config = diffusers_and_ours_config_mapping[config] + state_dict = {} + ckpt_paths = pretrained_model_path / "transformer" / "diffusion_pytorch_model*.safetensors" + dict_list = glob.glob(str(ckpt_paths)) + for dict_path in dict_list: + part_dict = {} + with safe_open(dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + part_dict[k] = f.get_tensor(k) + state_dict.update(part_dict) + + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + state_dict[new_key] = state_dict.pop(key) + + with torch.device("meta"): + transformer = cls.from_config(config) + transformer.load_state_dict(state_dict, assign=True, strict=True) + elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(".safetensors"): + comfy_single_file_state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + comfy_single_file_state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + transformer_config = configs["transformer"] + with torch.device("meta"): + transformer = Transformer3DModel.from_config(transformer_config) + transformer.load_state_dict(comfy_single_file_state_dict, assign=True) + return transformer + + def forward( + self, + hidden_states: torch.Tensor, + indices_grid: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. 
+ encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + skip_layer_mask ( `torch.Tensor`, *optional*): + A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position + `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index. + skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # for tpu attention offload 2d token masks are used. No need to transform. + if not self.use_tpu_flash_attention: + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + hidden_states = self.patchify_proj(hidden_states) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + freqs_cis = self.precompute_freqs_cis(indices_grid) + + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + + # 2. 
Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + for block_idx, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + freqs_cis, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + (skip_layer_mask[block_idx] if skip_layer_mask is not None else None), + skip_layer_strategy, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + skip_layer_mask=(skip_layer_mask[block_idx] if skip_layer_mask is not None else None), + skip_layer_strategy=skip_layer_strategy, + ) + + # 3. Output + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if not return_dict: + return (hidden_states,) + + return Transformer3DModelOutput(sample=hidden_states) diff --git a/src/maxdiffusion/models/ltx_video/utils/__init__.py b/src/maxdiffusion/models/ltx_video/utils/__init__.py new file mode 100644 index 000000000..285b6e81c --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main diff --git a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py new file mode 100644 index 000000000..bd4d115f9 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py @@ -0,0 +1,259 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +import argparse +import json +from typing import Any, Dict, Optional + + +import jax +import jax.numpy as jnp +from flax.training import train_state +import optax +import orbax.checkpoint as ocp +from safetensors.torch import load_file + +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTranformer3DModel +from maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d import Transformer3DModel as Transformer3DModel +from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax + +from huggingface_hub import hf_hub_download +import os + + +class Checkpointer: + """ + Checkpointer - to load and store JAX checkpoints + """ + + STATE_DICT_SHAPE_KEY = "shape" + STATE_DICT_DTYPE_KEY = "dtype" + TRAIN_STATE_FILE_NAME = "train_state" + + def __init__( + self, + checkpoint_dir: str, + use_zarr3: bool = False, + save_buffer_size: Optional[int] = None, + restore_buffer_size: Optional[int] = None, + ): + """ + Constructs the checkpointer object + """ + opts = ocp.CheckpointManagerOptions( + enable_async_checkpointing=True, + step_format_fixed_length=8, # to make the format of "00000000" + ) + self.use_zarr3 = use_zarr3 + self.save_buffer_size = save_buffer_size + self.restore_buffer_size = restore_buffer_size + registry = ocp.DefaultCheckpointHandlerRegistry() + self.train_state_handler = ocp.PyTreeCheckpointHandler( + save_concurrent_gb=save_buffer_size, + restore_concurrent_gb=restore_buffer_size, + use_zarr3=use_zarr3, + ) + registry.add( + self.TRAIN_STATE_FILE_NAME, + ocp.args.PyTreeSave, + self.train_state_handler, + ) + self.manager = ocp.CheckpointManager( + directory=checkpoint_dir, + options=opts, + handler_registry=registry, + ) + + @property + def save_buffer_size_bytes(self) -> Optional[int]: + if self.save_buffer_size is None: + return None + return self.save_buffer_size * 2**30 + + @staticmethod + def state_dict_to_structure_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Converts a state dict to a dictionary stating the shape and dtype of the state_dict elements. + With this, we can reconstruct the state_dict structure later on. 
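+ + Example (illustrative): + >>> Checkpointer.state_dict_to_structure_dict({"w": jnp.zeros((32, 64), jnp.bfloat16)}) + {'w': {'shape': (32, 64), 'dtype': 'bfloat16'}}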
+ """ + return jax.tree_util.tree_map( + lambda t: { + Checkpointer.STATE_DICT_SHAPE_KEY: tuple(t.shape), + Checkpointer.STATE_DICT_DTYPE_KEY: t.dtype.name, + }, + state_dict, + is_leaf=lambda t: isinstance(t, jax.Array), + ) + + def save( + self, + step: int, + state: train_state.TrainState, + config: Dict[str, Any], + ): + """ + Saves the checkpoint asynchronously + + NOTE that state is going to be copied for this operation + + Args: + step (int): The step of the checkpoint + state (TrainStateWithEma): A trainstate containing both the parameters and the optimizer state + config (Dict[str, Any]): A dictionary containing the configuration of the model + """ + self.wait() + args = ocp.args.Composite( + train_state=ocp.args.PyTreeSave( + state, + ocdbt_target_data_file_size=self.save_buffer_size_bytes, + ), + config=ocp.args.JsonSave(config), + meta_params=ocp.args.JsonSave(self.state_dict_to_structure_dict(state.params)), + ) + self.manager.save( + step, + args=args, + ) + + def wait(self): + """ + Waits for the checkpoint save operation to complete + """ + self.manager.wait_until_finished() + + +""" +Convert Torch checkpoints to JAX. + +This script loads a Torch checkpoint (either regular or sharded), converts it to Jax weights, and saved it. +""" + + +def main(args): + """ + Convert a Torch checkpoint into JAX. + """ + + if args.output_step_num > 1: + print( + "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between " + "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in " + "training loss when resuming from the converted checkpoint." + ) + + print("Loading safetensors, flush = True") + weight_file = "ltxv-13b-0.9.7-dev.safetensors" + + # download from huggingface, otherwise load from local + + print("Loading from HF", flush=True) + model_name = "Lightricks/LTX-Video" + absolute_ckpt_path = os.path.abspath(args.ckpt_path) + local_file_path = hf_hub_download( + repo_id=model_name, + filename=weight_file, + local_dir=absolute_ckpt_path, + local_dir_use_symlinks=False, + ) + torch_state_dict = load_file(local_file_path) + + print("Initializing pytorch transformer..", flush=True) + transformer_config = json.loads(open(args.transformer_config_path, "r").read()) + ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "ckpt_path"] + for key in ignored_keys: + if key in transformer_config: + del transformer_config[key] + + transformer = Transformer3DModel.from_config(transformer_config) + + print("Loading torch weights into transformer..", flush=True) + transformer.load_state_dict(torch_state_dict) + torch_state_dict = transformer.state_dict() + + print("Creating jax transformer with params..", flush=True) + transformer_config["use_tpu_flash_attention"] = True + in_channels = transformer_config["in_channels"] + del transformer_config["in_channels"] + jax_transformer3d = JaxTranformer3DModel( + **transformer_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch" + ) + example_inputs = {} + batch_size, num_tokens = 2, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, transformer_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + 
shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + params_jax = jax_transformer3d.init(jax.random.PRNGKey(42), **example_inputs) + + print("Converting torch params to jax..", flush=True) + params_jax = torch_statedict_to_jax(params_jax, torch_state_dict) + + print("Creating checkpointer and jax state for saving..", flush=True) + relative_ckpt_path = os.path.join(args.ckpt_path, "jax_weights") + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + tx = optax.adamw(learning_rate=1e-5) + with jax.default_device("cpu"): + state = train_state.TrainState( + step=args.output_step_num, + apply_fn=jax_transformer3d.apply, + params=params_jax, + tx=tx, + opt_state=tx.init(params_jax), + ) + with ocp.CheckpointManager(absolute_ckpt_path) as mngr: + mngr.save(args.output_step_num, args=ocp.args.StandardSave(state.params)) + print("Done.", flush=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Torch checkpoints to Jax format.") + parser.add_argument( + "--ckpt_path", + type=str, + required=False, + help="Local path for the checkpoint to convert, for example '/mnt/ckpt/00536000' or '/opt/dmd-torch-model/ema.pt'. If the weights are not already present there, they are downloaded from Hugging Face.", + ) + + parser.add_argument( + "--output_step_num", + default=1, + type=int, + required=False, + help=( + "The step number to assign to the output checkpoint. The result will be saved using this step value. " + "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between " + "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in " + "training loss when resuming from the converted checkpoint." + ), + ) + parser.add_argument( + "--transformer_config_path", + default="/opt/txt2img/txt2img/config/transformer3d/ltxv2B-v1.0.json", + type=str, + required=False, + help="Path to Transformer3D structure config to load the weights based on.", + ) + + args = parser.parse_args() + main(args) diff --git a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py new file mode 100644 index 000000000..81094d676 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py @@ -0,0 +1,190 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +def make_hashable_key(dict_key): + def convert_value(value): + if isinstance(value, list): + return tuple(value) + elif isinstance(value, dict): + return tuple(sorted((k, convert_value(v)) for k, v in value.items())) + else: + return value + + return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) + + +DIFFUSERS_SCHEDULER_CONFIG = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "_diffusers_version": "0.32.0.dev0", + "base_image_seq_len": 1024, + "base_shift": 0.95, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 2.05, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.1, + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} +DIFFUSERS_TRANSFORMER_CONFIG = { + "_class_name": "LTXVideoTransformer3DModel", + "_diffusers_version": "0.32.0.dev0", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_out_bias": True, + "caption_channels": 4096, + "cross_attention_dim": 2048, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_layers": 28, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm_across_heads", +} +DIFFUSERS_VAE_CONFIG = { + "_class_name": "AutoencoderKLLTXVideo", + "_diffusers_version": "0.32.0.dev0", + "block_out_channels": [128, 256, 512, 512], + "decoder_causal": False, + "encoder_causal": True, + "in_channels": 3, + "latent_channels": 128, + "layers_per_block": [4, 3, 3, 3, 4], + "out_channels": 3, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-06, + "scaling_factor": 1.0, + "spatio_temporal_scaling": [True, True, True, False], +} + +OURS_SCHEDULER_CONFIG = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": "0.25.1", + "num_train_timesteps": 1000, + "shifting": "SD3", + "base_resolution": None, + "target_shift_terminal": 0.1, +} + +OURS_TRANSFORMER_CONFIG = { + "_class_name": "Transformer3DModel", + "_diffusers_version": "0.25.1", + "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 2048, + "double_self_attention": False, + "dropout": 0.0, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "norm_num_groups": 32, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 28, + "num_vector_embeds": None, + "only_cross_attention": False, + "out_channels": 128, + "project_to_2d_pos": True, + "upcast_attention": False, + "use_linear_projection": False, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, +} +OURS_VAE_CONFIG = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": 
"uniform", + "use_quant_conv": False, + "causal_decoder": False, +} + + +diffusers_and_ours_config_mapping = { + make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, + make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, + make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, +} + + +TRANSFORMER_KEYS_RENAME_DICT = { + "proj_in": "patchify_proj", + "time_embed": "adaln_single", + "norm_q": "q_norm", + "norm_k": "k_norm", +} + + +VAE_KEYS_RENAME_DICT = { + "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", + "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", + "decoder.up_blocks.3": "decoder.up_blocks.9", + "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", + "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", + "decoder.up_blocks.2": "decoder.up_blocks.6", + "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", + "decoder.up_blocks.1": "decoder.up_blocks.3", + "decoder.up_blocks.0": "decoder.up_blocks.1", + "decoder.mid_block": "decoder.up_blocks.0", + "encoder.down_blocks.3": "encoder.down_blocks.8", + "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", + "encoder.down_blocks.2": "encoder.down_blocks.6", + "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", + "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", + "encoder.down_blocks.1": "encoder.down_blocks.3", + "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", + "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", + "encoder.down_blocks.0": "encoder.down_blocks.0", + "encoder.mid_block": "encoder.down_blocks.9", + "conv_shortcut.conv": "conv_shortcut", + "resnets": "res_blocks", + "norm3": "norm3.norm", + "latents_mean": "per_channel_statistics.mean-of-means", + "latents_std": "per_channel_statistics.std-of-means", +} diff --git a/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py new file mode 100644 index 000000000..1d404be39 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py @@ -0,0 +1,206 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +import logging +from typing import Union, List, Optional + +import torch +from PIL import Image + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. +Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. +Start directly with the action, and keep descriptions literal and precise. +Think like a cinematographer describing a shot list. +Do not change the user input intent, just enhance it. +Keep within 150 words. 
+For best results, build your prompts using this structure: +Start with main action in a single sentence +Add specific details about movements and gestures +Describe character/object appearances precisely +Include background and environment details +Specify camera angles and movements +Describe lighting and colors +Note any changes or sudden events +Do not exceed the 150 word limit! +Output the enhanced prompt only. +""" + +I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. +Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. +Start directly with the action, and keep descriptions literal and precise. +Think like a cinematographer describing a shot list. +Keep within 150 words. +For best results, build your prompts using this structure: +Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input. +Start with main action in a single sentence +Add specific details about movements and gestures +Describe character/object appearances precisely +Include background and environment details +Specify camera angles and movements +Describe lighting and colors +Note any changes or sudden events +Align to the image caption if it contradicts the user text input. +Do not exceed the 150 word limit! +Output the enhanced prompt only. +""" + + +def tensor_to_pil(tensor): + # Ensure tensor is in range [-1, 1] + assert tensor.min() >= -1 and tensor.max() <= 1 + + # Convert from [-1, 1] to [0, 1] + tensor = (tensor + 1) / 2 + + # Rearrange from [C, H, W] to [H, W, C] + tensor = tensor.permute(1, 2, 0) + + # Convert to numpy array and then to uint8 range [0, 255] + numpy_image = (tensor.cpu().numpy() * 255).astype("uint8") + + # Convert to PIL Image + return Image.fromarray(numpy_image) + + +def generate_cinematic_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompt: Union[str, List[str]], + conditioning_items: Optional[List] = None, + max_new_tokens: int = 256, +) -> List[str]: + prompts = [prompt] if isinstance(prompt, str) else prompt + + if conditioning_items is None: + prompts = _generate_t2v_prompt( + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts, + max_new_tokens, + T2V_CINEMATIC_PROMPT, + ) + else: + if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0: + logger.warning( + "prompt enhancement only supports unconditional generation or conditioning on the first frame; returning original prompts" + ) + return prompts + + first_frame_conditioning_item = conditioning_items[0] + first_frames = _get_first_frames_from_conditioning_item(first_frame_conditioning_item) + + assert len(first_frames) == len(prompts), "Number of conditioning frames must match number of prompts" + + prompts = _generate_i2v_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts, + first_frames, + max_new_tokens, + I2V_CINEMATIC_PROMPT, + ) + + return prompts + + +def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]: + frames_tensor = conditioning_item.media_item + return [tensor_to_pil(frames_tensor[i, :, 0, :, :]) for i in range(frames_tensor.shape[0])] + + +def _generate_t2v_prompt( + prompt_enhancer_model,
prompt_enhancer_tokenizer, + prompts: List[str], + max_new_tokens: int, + system_prompt: str, +) -> List[str]: + messages = [ + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user_prompt: {p}"}, + ] + for p in prompts + ] + + texts = [prompt_enhancer_tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages] + model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(prompt_enhancer_model.device) + + return _generate_and_decode_prompts(prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens) + + +def _generate_i2v_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts: List[str], + first_frames: List[Image.Image], + max_new_tokens: int, + system_prompt: str, +) -> List[str]: + image_captions = _generate_image_captions(image_caption_model, image_caption_processor, first_frames) + + messages = [ + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"}, + ] + for p, c in zip(prompts, image_captions) + ] + + texts = [prompt_enhancer_tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages] + model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(prompt_enhancer_model.device) + + return _generate_and_decode_prompts(prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens) + + +def _generate_image_captions( + image_caption_model, + image_caption_processor, + images: List[Image.Image], + system_prompt: str = "", +) -> List[str]: + image_caption_prompts = [system_prompt] * len(images) + inputs = image_caption_processor(image_caption_prompts, images, return_tensors="pt").to(image_caption_model.device) + + with torch.inference_mode(): + generated_ids = image_caption_model.generate( + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=1024, + do_sample=False, + num_beams=3, + ) + + return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True) + + +def _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int +) -> List[str]: + with torch.inference_mode(): + outputs = prompt_enhancer_model.generate(**model_inputs, max_new_tokens=max_new_tokens) + generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, outputs)] + decoded_prompts = prompt_enhancer_tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + return decoded_prompts diff --git a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py new file mode 100644 index 000000000..74e74c1c6 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py @@ -0,0 +1,24 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
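For reference, a standalone sketch of the chat-template round trip that _generate_t2v_prompt and _generate_and_decode_prompts perform above; the model id below is a placeholder for any instruct-tuned chat model, not something this patch pins down.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "some-org/some-instruct-model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

messages = [
    {"role": "system", "content": T2V_CINEMATIC_PROMPT},
    {"role": "user", "content": "user_prompt: a cat chasing a laser pointer"},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([text], return_tensors="pt")
with torch.inference_mode():
  outputs = model.generate(**inputs, max_new_tokens=256)
# Strip the prompt tokens before decoding, as _generate_and_decode_prompts does:
enhanced = tokenizer.batch_decode([outputs[0][inputs.input_ids.shape[1]:]], skip_special_tokens=True)[0]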
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from enum import Enum, auto + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_compat.py b/src/maxdiffusion/models/ltx_video/utils/torch_compat.py new file mode 100644 index 000000000..6cbfea70b --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/torch_compat.py @@ -0,0 +1,536 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +import re +from copy import copy +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Union, Any + +import flax +import jax +import torch +import torch.utils._pytree as pytree +from flax.traverse_util import flatten_dict + + +AnyTensor = Union[jax.Array, torch.Tensor] +StateDict = Dict[str, AnyTensor] + +ScanRepeatableCarryBlock = "ScanRepeatableCarryBlock" + +JaxParams = Dict[str, Union[Dict[str, jax.Array], jax.Array]] + + +def unbox_logically_partioned(statedict: JaxParams) -> JaxParams: + return jax.tree_util.tree_map( + lambda t: t.unbox() if isinstance(t, flax.linen.spmd.LogicallyPartitioned) else t, + statedict, + is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned), + ) + + +def torch_tensor_to_jax_array(data: torch.Tensor) -> jax.Array: + match data.dtype: + case torch.bfloat16: + return jax.numpy.from_dlpack(data) + case _: + return jax.numpy.array(data) + + +def is_stack_or_tensor(param: Any) -> bool: + """ + Returns True if param is of type tensor or list/tuple of tensors (stack of tensors) + + Used for mapping utils + """ + return isinstance(param, (torch.Tensor, list, tuple)) + + +def convert_tensor_stack_to_tensor(param: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: + """ + Converts a list of torch tensors to a single torch tensor. + Args: + param (Union[List[torch.Tensor], torch.Tensor]): The parameter to convert. + + Returns: + torch.Tensor: The converted tensor. + """ + if isinstance(param, list): + return torch.stack(param) + return param + + +@dataclass +class ConvertAction: + """ + Defines a set of actions to be done on a given parameter. + + The definition must be commutative, i.e. the order of the actions should not matter. + also we should strive for actions to be reversible (so the same action can be used for both directions). + """ + + transpose: Optional[Tuple[int, int]] = None + """ + If defined, transposes the tensor with the given indices. + Example: (1, 0) transposes a (at least 2D tensor) from (..., a, b) to (..., b, a). + """ + + rename: Optional[Dict[str, str]] = None + """ + If defined, renames the parameter according to the given mapping. 
+ Example: {"torch": "weight", "jax": "kernel"} + * renames "torch.weight" to "jax.kernel" when converting from torch to jax. + * renames "jax.kernel" to "torch.weight" when converting from jax to torch. + """ + + split_by: Optional[str] = None + """ + If defined, splits the parameter by the given delimiter. + Example: "ScanRepeatableCarryBlock.k1" assumes the parameter is a concatenation of multiple tensors (shaped: (n, ...)). + and splits them into individual tensors named as "ScanRepeatableCarryBlock.0.k1", "ScanRepeatableCarryBlock.n.k1". + """ + + group_by: Optional[str] = None + """ + If defined, groups the parameter by the given delimiter. + Example: "ScanRepeatableCarryBlock.0.k1", "ScanRepeatableCarryBlock.1.k1", "ScanRepeatableCarryBlock.2.k1" + will be grouped into a single tensor named "ScanRepeatableCarryBlock.k1" shaped (n, ...). + + *** Note: + this is kind of the reverse of split_by, only a different behavior. + it's easy to define "actions" that are reversible in base of context (jax->torch, torch->jax). + but it's very wrong to do so, since it blocks modular behavior and makes the code harder to maintain. + + """ + + jax_groups: Optional[List[str]] = None + """ + Generally used in group_by, this is a list of all possible keys that can be used to group the parameters. + This must be defined if group_by is defined. + + It's due to the un-reversibility nature of the group_by action. + """ + + def apply_transpose(self, mini_statedict: StateDict) -> StateDict: + """ + Applies the transpose action if defined + Args: + mini_statedict (StateDict): Local context of the state dict + + Returns: + StateDict: Output local context of the state dict + """ + + if self.transpose is None: + return mini_statedict + index0, index1 = self.transpose + return {param_name: param.swapaxes(index0, index1) for param_name, param in mini_statedict.items()} + + def apply_rename(self, mini_statedict: StateDict, delim: str) -> StateDict: + """ + Applies the rename action if defined + + Args: + mini_statedict (StateDict): Local context of the state dict + delim (str): delimiter used for parsing (usually "."), kept as parameter for flexibility. + + Returns: + StateDict: Output local context of the state dict + """ + if self.rename is None: + return mini_statedict + + param_names = list(mini_statedict.keys()) + for param_name in param_names: + param = mini_statedict.pop(param_name) + parts = param_name.split(delim) + rename_source = "torch" if isinstance(param, torch.Tensor) else "jax" + rename_target = "jax" if isinstance(param, torch.Tensor) else "torch" + source_name = self.rename[rename_source] + dest_name = self.rename[rename_target] + if source_name == param_name: + new_param_name = dest_name + else: + # There is always ```self.rename[rename_source]``` in parts + index = parts.index(self.rename[rename_source]) + parts[index] = self.rename[rename_target] + new_param_name = delim.join(parts) + mini_statedict[new_param_name] = param + + return mini_statedict + + def apply_split_by(self, mini_statedict: StateDict, new_params: List, delim: str) -> Tuple[StateDict, List[str]]: + """ + Applies the split_by action if defined + + Args: + mini_statedict (StateDict): Local state dict + new_params (List): State containing list of new params that were created during the process (if any) + delim (str): Output local context of the state dict + + Returns: + Tuple[StateDict, List[str]]: Output local context of the state dict and list of new keys to add to the global state dict. 
+ """ + if self.split_by is None: + return mini_statedict, new_params + + param_names = list(mini_statedict.keys()) + for param_name in param_names: + parts = param_name.split(delim) + indices = [i for i, p in enumerate(parts) if self.split_by in p] + if len(indices) != 1: + raise ValueError(f"Expected exactly one split_by in param_name: {param_name}") + index = indices[0] + params = mini_statedict.pop(param_name) + for i, param in enumerate(params): + new_parts = parts[:index] + [f"{i}"] + parts[index + 2 :] + new_param_name = delim.join(new_parts) + mini_statedict[new_param_name] = param + new_params.append(new_param_name) + + return mini_statedict, new_params + + def apply_group_by( + self, mini_statedict: StateDict, new_params: List, full_statedict: StateDict, delim: str + ) -> Tuple[StateDict, List[str]]: + """ + Applies the group_by action if defined + + Args: + mini_statedict (StateDict): Local state dict + new_params (List): State containing list of new params that were created during the process (if any) + full_statedict (StateDict): Global context of the state dict + delim (str): delimiter used for parsing (usually "."), kept as parameter for flexibility. + + Returns: + Tuple[StateDict, List[str]]: Output local context of the state dict and list of new keys to add to the global state dict. + """ + if self.group_by is None: + return mini_statedict, new_params + + param_names = list(mini_statedict.keys()) + for param_name in param_names: + param = mini_statedict.pop(param_name) + jax_keywords = extract_scan_keywords(param_name, self.jax_groups, delim) + block_index = re.findall(r"\.\d+\.", param_name)[0][1:-1] + parts = param_name.split(delim) + index = parts.index(block_index) + prefix = delim.join(parts[:index]) + suffix = delim.join(parts[index + 1 :]) + + new_param_name = f"{prefix}.{delim.join(jax_keywords)}.{suffix}" + + if new_param_name not in full_statedict: + full_statedict[new_param_name] = [param] + else: + full_statedict[new_param_name] = full_statedict[new_param_name] + [param] + + return mini_statedict, new_params + + def __call__( + self, + mini_statedict: StateDict, + new_params: List, + full_statedict: StateDict, + delim: str, + ) -> Tuple[StateDict, List[str]]: + """ + Given a state dict, applies the transformations defined in the ConvertAction. + + Args: + mini_statedict (StateDict): Local context of the state dict + new_params (List): new params that were created during the process (if any) + full_statedict (StateDict): Global context of the state dict + delim (str): delimiter used for parsing (usually "."), kept as parameter for flexibility. + + Returns: + Tuple[StateDict, List[str]]: Updated local state dict and list of new keys to add to the global state dict. + """ + mini_statedict = self.apply_transpose(mini_statedict) + mini_statedict = self.apply_rename(mini_statedict, delim) + mini_statedict, new_params = self.apply_split_by(mini_statedict, new_params, delim) + mini_statedict, new_params = self.apply_group_by(mini_statedict, new_params, full_statedict, delim) + return mini_statedict, new_params + + +def is_kernel_2d(param_name: str, param: AnyTensor) -> bool: + """ + Checks if the parameter is a 2D kernel (weight) or not. + usually applies to linear layers or convolutions. + Args: + param_name (str): Name of the parameter + param (AnyTensor): The parameter itself (could be either jax or torch Tensor) + + Returns: + bool: True if the parameter is a weight for linear/convolutional layer or not. 
+ """ + expected_name = "weight" if isinstance(param, torch.Tensor) else "kernel" + return expected_name in param_name and param.ndim == 2 + + +def is_scan_repeatable(param_name: str, _) -> bool: + """ + Checks if the parameter is a scan repeatable carry block parameter. + + Args: + param_name (str): Parameter name + _ (_type_): Unused, will contain the parameter itself + + Returns: + bool: True if the parameter is a scan repeatable carry block parameter or not. + """ + return ScanRepeatableCarryBlock in param_name + + +def is_scale_shift_table(param_name: str, _) -> bool: + """ + Checks if the parameter is a scale shift table parameter. + + Args: + param_name (str): Parameter name + _ (_type_): Unused, will contain the parameter itself + + Returns: + bool: True if the parameter is a scale shift table parameter or not. + """ + return "scale_shift_table" in param_name + + +def is_affine_scale_param(param_name: str, parameter: AnyTensor, jax_flattened_keys: List[str]) -> bool: + """ + Checks if the parameter is an affine scale parameter. + + Args: + param_name (str): Parameter name + parameter (AnyTensor): The parameter itself + jax_flattened_keys (List[str]): Flattened list of the keys use in jax (for reference and keys search) + + + Returns: + bool: True if the parameter is an affine scale parameter or not. + """ + if isinstance(parameter, torch.Tensor): + return "weight" in param_name and parameter.ndim == 1 and param_name not in jax_flattened_keys + else: + return "scale" in param_name and parameter.ndim == 1 + + +def extract_scan_keywords(param_name: str, jax_flattened_keys: List[str], delim: str) -> Optional[Tuple[str, str]]: + """ + Extracts the keywords from the scan repeatable carry block parameter (if exists) + + If the parameter is a scan repeatable carry block, it will return the keywords that are used to group the parameters. + otherwise it will return None. + + Args: + param_name (str): Name of the parameter + jax_flattened_keys (List[str]): Flattened list of the keys use in jax (for reference and keys search) + delim (str): The delimiter used in the parameter name (in torch) + + Returns: + Optional[Tuple[str, str]]: Tuple of the keywords used to group the parameters (or None if it is not a scan repeatable carry block) + """ + block_indices = re.findall(r"\.\d+\.", param_name) + + if len(block_indices) == 0: + return None + block_indices = [block_indices[0]] + block_index = block_indices[0][1:-1] + parts = param_name.split(delim) + index = parts.index(block_index) + prefix = delim.join(parts[:index]) + suffix = delim.join(parts[index + 1 :]) + + for flat_key in jax_flattened_keys: + if flat_key.startswith(prefix) and flat_key.endswith(suffix): + mid_layer = flat_key[len(prefix) + 1 : -len(suffix) - 1] + mid_parts = mid_layer.split(delim) + if not any(ScanRepeatableCarryBlock in mid_part for mid_part in mid_parts): + continue + return mid_parts + + return None + + +def should_be_scan_repeatable(param_name: str, param: AnyTensor, jax_flattened_keys: List[str], delim: str) -> bool: + """ + Checks if the parameter should be a scan repeatable carry block or not. + Args: + param_name (str): The name of the parameter + param (AnyTensor): the Parameter itself + jax_flattened_keys (List[str]): Flattened list of the keys use in jax (for reference and keys search) + delim (str): The delimiter used in the parameter name (in torch) + + Returns: + bool: True if the paramter should be treated scan repeatable block parameter. 
+ """ + if not isinstance(param, torch.Tensor): + return False + + keywords = extract_scan_keywords(param_name, jax_flattened_keys, delim) + return keywords is not None + + +def jax_statedict_to_torch( + jax_params: JaxParams, rulebook: Optional[Dict[Callable[[str, AnyTensor], bool], ConvertAction]] = None +) -> Dict[str, torch.Tensor]: + """ + Converts a JAX state dict to a torch state dict. + + Args: + jax_params (JaxParams): The current params in JAX format, to ease parsing and conversion. + rulebook (Optional[Dict[Callable[[str, AnyTensor], bool], ConvertAction]], optional): Defines a rulebook stating how to convert state dict from jax to torch. + Defaults to None. + + + Returns: + Dict[str, torch.Tensor]: The converted state dict in torch format (Pytorch state dict). + """ + + affine_scale_search = partial(is_affine_scale_param, jax_flattened_keys=[]) + + if rulebook is None: + rulebook = { + is_scan_repeatable: ConvertAction(split_by=ScanRepeatableCarryBlock), + is_kernel_2d: ConvertAction(transpose=(1, 0), rename=dict(torch="weight", jax="kernel")), # noqa C408 + affine_scale_search: ConvertAction(rename=dict(torch="weight", jax="scale")), # noqa C408 + } + if "params" not in jax_params: + raise ValueError('Expected "params" key in jax_params, are you sure you are passing the correct object?') + + jax_params = copy(jax_params["params"]) # Non reference copy + jax_params = unbox_logically_partioned(jax_params) + + delim = "." + # Move to flattened dict to match torch state dict convention + flattened_params = flatten_dict(jax_params, sep=delim) + + param_names = list(flattened_params.keys()) + for param_name in param_names: + param = flattened_params.pop(param_name) + mini_statedict = {param_name: param} + new_params = [] + for condition, rule in rulebook.items(): + if condition(param_name, param): + mini_statedict, new_params = rule(mini_statedict, new_params, flattened_params, delim) + if len(mini_statedict) == 1: + param_name = list(mini_statedict.keys())[0] + + flattened_params.update(mini_statedict) + param_names.extend(new_params) + + flattened_params = pytree.tree_map(convert_tensor_stack_to_tensor, flattened_params, is_leaf=is_stack_or_tensor) + + to_cpu = pytree.tree_map(lambda t: jax.device_put(t, jax.devices("cpu")[0]), flattened_params) + to_torch = pytree.tree_map(torch.from_dlpack, to_cpu) + return to_torch + + +def torch_statedict_to_jax( + jax_params: JaxParams, + torch_params: Dict[str, torch.Tensor], +) -> JaxParams: + """ + Converts a torch state dict to a JAX state dict. + + Args: + jax_params (JaxParams): The current params in JAX format, to ease parsing and conversion. + torch_params (Dict[str, torch.Tensor]): The current params in torch format, to load parameters from. + + Returns: + JaxParams: The state dict in JAX format. + """ + with jax.default_device("cpu"): + jax_params = copy(jax_params) + jax_params = unbox_logically_partioned(jax_params) + torch_params = copy(torch_params) + + if "params" not in jax_params: + raise ValueError('Expected "params" key in jax_params, are you sure you are passing the correct object?') + + delim = "." 
+ flattened_keys = list(flatten_dict(jax_params["params"], sep=".").keys()) + scan_repeatable_cond = partial(should_be_scan_repeatable, jax_flattened_keys=flattened_keys, delim=delim) + affine_scale_search = partial(is_affine_scale_param, jax_flattened_keys=flattened_keys) + + rulebook = { + is_kernel_2d: ConvertAction(transpose=(1, 0), rename=dict(torch="weight", jax="kernel")), # noqa C408 + affine_scale_search: ConvertAction(rename=dict(torch="weight", jax="scale")), # noqa C408 + scan_repeatable_cond: ConvertAction(group_by=ScanRepeatableCarryBlock, jax_groups=flattened_keys), + } + + # First pass - Rulebook + param_names = list(torch_params.keys()) + for param_name in param_names: + param = torch_params.pop(param_name) + mini_statedict = {param_name: param} + new_params = [] + for condition, rule in rulebook.items(): + if condition(param_name, param): + mini_statedict, new_params = rule(mini_statedict, new_params, torch_params, delim=delim) + if len(mini_statedict) == 1: + param_name = list(mini_statedict.keys())[0] + + torch_params.update(mini_statedict) + param_names.extend(new_params) + + # Ensures any list of tensors are converted to a single tensor + # This is due to the fact that the scan repeatable block is a list of tensors + torch_params = pytree.tree_map(convert_tensor_stack_to_tensor, torch_params, is_leaf=is_stack_or_tensor) + + to_jax: Dict = pytree.tree_map(torch_tensor_to_jax_array, torch_params) + + def nested_insert(param_name: str, param: torch.Tensor, nested_dict: Dict): + """ + Inserts a parameter into a nested dictionary. (to fit Jax format) + The keys in torch are split into groups by a delimiter of choice (usually "." to fit torch schema) + and then inserted into a nested dictionary. + + in case the parameter is of the form of "a.b" and "a.b" is a layer type in jax - + the parameter will be inserted as "a.b": {...: param}. this ensures compatibility between jax layers and torch layers. + + Args: + param_name (str): Parameter name + param (torch.Tensor): Parameter itself + nested_dict (Dict): Current nested dict state + """ + if delim not in param_name: + nested_dict[param_name] = param + return + + parts = param_name.split(delim) + if len(parts) == 1: + return nested_insert(parts[0], param, nested_dict) + else: + key = parts[0] + # May be either complex key or nested key + if len(parts) > 2 and re.fullmatch(r"\d+", parts[1]) is not None: + key = delim.join(parts[:2]) + new_param_name = delim.join(parts[2:]) + else: + new_param_name = delim.join(parts[1:]) + new_nested_dict = nested_dict.get(key, {}) + nested_dict[key] = new_nested_dict + return nested_insert(new_param_name, param, new_nested_dict) + + params = {} + for param_name, param in to_jax.items(): + nested_insert(param_name, param, params) + + # Jax state dict is usually held as dict containings "parmas" keys which contains + # dict of dict containing all the params + return {"params": params} diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_utils.py b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py new file mode 100644 index 000000000..6dca31b1f --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py @@ -0,0 +1,39 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
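A correspondingly minimal sketch of the inverse direction via jax_statedict_to_torch above, run on a hand-built tree; the names and shapes are illustrative.

import jax.numpy as jnp

from maxdiffusion.models.ltx_video.utils.torch_compat import jax_statedict_to_torch

params = {"params": {"proj": {"kernel": jnp.ones((8, 4))}}}
torch_sd = jax_statedict_to_torch(params)
# 2D "kernel" tensors are renamed back to "weight" and transposed to torch layout:
assert tuple(torch_sd["proj.weight"].shape) == (4, 8)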
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +import torch +from torch import nn + + +def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + elif dims_to_append == 0: + return x + return x[(...,) + (None,) * dims_to_append] + + +class Identity(nn.Module): + """A placeholder identity operator that is argument-insensitive.""" + + def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument + super().__init__() + + # pylint: disable=unused-argument + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return x diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json new file mode 100644 index 000000000..75b16b011 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -0,0 +1,25 @@ +{ + "activation_fn": "gelu-approximate", + "attention_bias": true, + "attention_head_dim": 128, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 4096, + "double_self_attention": false, + "dropout": 0.0, + "norm_elementwise_affine": false, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 48, + "only_cross_attention": false, + "out_channels": 128, + "upcast_attention": false, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, + "in_channels": 128 +} \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/ltx_video/__init__.py b/src/maxdiffusion/pipelines/ltx_video/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/src/maxdiffusion/pipelines/ltx_video/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
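A quick check of the append_dims helper from torch_utils above:

import torch

from maxdiffusion.models.ltx_video.utils.torch_utils import append_dims

x = torch.ones(2, 3)
assert append_dims(x, 4).shape == (2, 3, 1, 1)  # trailing singleton dims are appended
assert append_dims(x, 2) is x  # already at the target rank, so returned unchanged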
diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py new file mode 100644 index 000000000..0ca816f9e --- /dev/null +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -0,0 +1,1039 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import os +from jax import Array +from typing import Optional, List, Union, Tuple +from einops import rearrange +import torch.nn.functional as F +from maxdiffusion.models.ltx_video.autoencoders.vae_torchax import TorchaxCausalVideoAutoencoder +from transformers import (FlaxT5EncoderModel, AutoTokenizer) +from torchax import interop +from torchax import default_env +import json +import numpy as np +import torch +from transformers import ( + AutoModelForCausalLM, + AutoProcessor, +) +from maxdiffusion import max_logging +from huggingface_hub import hf_hub_download +from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from maxdiffusion.models.ltx_video.autoencoders.vae_encode import ( + get_vae_size_scale_factor, + latent_to_pixel_coords, + un_normalize_latents, + normalize_latents, +) +from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler +from maxdiffusion.models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt +from types import NoneType +from typing import Any, Dict + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier +from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from ...pyconfig import HyperParameters +from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState +from ...max_utils import (create_device_mesh, setup_initial_state, get_memory_allocations) +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import functools +import orbax.checkpoint as ocp + + +def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, encoder_attention_segment_ids): + # Note: reference shapes annotated for first-pass default inference parameters + max_logging.log(f"prompt_embeds.shape: {prompt_embeds.shape} {prompt_embeds.dtype}") # (3, 256, 4096) float32 + max_logging.log(f"fractional_coords.shape: {fractional_coords.shape} {fractional_coords.dtype}") # (3, 3, 3072) float32 + max_logging.log(f"latents.shape: {latents.shape} {latents.dtype}") # (1, 3072, 128) float32 + max_logging.log( + f"encoder_attention_segment_ids.shape: {encoder_attention_segment_ids.shape} {encoder_attention_segment_ids.dtype}" + ) # (3, 256) int32 + + +class LTXVideoPipeline: + + def __init__( + self, + transformer: Transformer3DModel, + scheduler: FlaxRectifiedFlowMultistepScheduler, + scheduler_state: RectifiedFlowSchedulerState, + vae:
TorchaxCausalVideoAutoencoder, + text_encoder, + patchifier, + tokenizer, + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + devices_array: np.ndarray, + mesh: Mesh, + config: HyperParameters, + transformer_state: Dict[Any, Any] = None, + transformer_state_shardings: Dict[Any, Any] = None, + ): + self.transformer = transformer + self.devices_array = devices_array + self.mesh = mesh + self.config = config + self.p_run_inference = None + self.transformer_state = transformer_state + self.transformer_state_shardings = transformer_state_shardings + self.scheduler = scheduler + self.scheduler_state = scheduler_state + self.vae = vae + self.text_encoder = text_encoder + self.patchifier = patchifier + self.tokenizer = tokenizer + self.prompt_enhancer_image_caption_model = prompt_enhancer_image_caption_model + self.prompt_enhancer_image_caption_processor = prompt_enhancer_image_caption_processor + self.prompt_enhancer_llm_model = prompt_enhancer_llm_model + self.prompt_enhancer_llm_tokenizer = prompt_enhancer_llm_tokenizer + self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(self.vae) + + @classmethod + def load_scheduler(cls, ckpt_path, config): + if config.sampler == "from_checkpoint" or not config.sampler: + scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) + else: + scheduler = FlaxRectifiedFlowMultistepScheduler( + sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic") + ) + scheduler_state = scheduler.create_state() + + return scheduler, scheduler_state + + @classmethod + def load_transformer(cls, config): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + with open(config.config_path, "r") as f: + model_config = json.load(f) + + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + key = jax.random.PRNGKey(config.seed) + key, subkey = jax.random.split(key) + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, subkey, model_config["caption_channels"], eval_only=True + ) + # Load from the converted weight checkpoint + models_dir = config.output_dir + jax_weights_dir = os.path.join(models_dir, "jax_weights") + checkpoint_manager = ocp.CheckpointManager(jax_weights_dir) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item="ltxvid_transformer", + model_params=None, + training=False, + ) + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + return transformer, transformer_state, transformer_state_shardings + + @classmethod + def load_vae(cls, ckpt_path): + torch_vae = CausalVideoAutoencoder.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16) + # Move the torch VAE into torchax so it runs on JAX devices + with default_env(): + torch_vae = torch_vae.to("jax") + jax_vae = TorchaxCausalVideoAutoencoder(torch_vae) + return jax_vae + + @classmethod + def load_text_encoder(cls, ckpt_path): + t5_encoder =
FlaxT5EncoderModel.from_pretrained(ckpt_path) + return t5_encoder + + @classmethod + def load_tokenizer(cls, config, ckpt_path): + t5_tokenizer = AutoTokenizer.from_pretrained(ckpt_path, max_length=config.max_sequence_length, use_fast=True) + return t5_tokenizer + + @classmethod + def load_prompt_enhancement(cls, config): + prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, + torch_dtype="bfloat16", + ) + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, + ) + return ( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + ) + + @classmethod + def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + transformer, transformer_state, transformer_state_shardings = cls.load_transformer(config) + + models_dir = config.output_dir + ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" + if not os.path.isfile(ltxv_model_name_or_path): + ltxv_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=ltxv_model_name_or_path, + local_dir=models_dir, + repo_type="model", + ) + else: + ltxv_model_path = ltxv_model_name_or_path + + scheduler, scheduler_state = cls.load_scheduler(ltxv_model_path, config) + + vae = cls.load_vae(ltxv_model_path) + text_encoder = cls.load_text_encoder(config.text_encoder_model_name_or_path) + patchifier = SymmetricPatchifier(patch_size=1) + tokenizer = cls.load_tokenizer(config, config.text_encoder_model_name_or_path) + + if enhance_prompt: + ( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + ) = cls.load_prompt_enhancement(config) + else: + ( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + ) = (None, None, None, None) + + return LTXVideoPipeline( + transformer=transformer, + scheduler=scheduler, + scheduler_state=scheduler_state, + vae=vae, + text_encoder=text_encoder, + patchifier=patchifier, + tokenizer=tokenizer, + prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model=prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, + devices_array=devices_array, + mesh=mesh, + config=config, + transformer_state=transformer_state, + transformer_state_shardings=transformer_state_shardings, + ) + + def _text_preprocessing(self, text): + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + text = text.strip() + return text + + return [process(t) for t in text] + + @staticmethod + def denoising_step( + scheduler, + latents: Array, + noise_pred: Array, + current_timestep: Optional[Array], + conditioning_mask: Optional[Array], + t: float, + extra_step_kwargs: Dict, + t_eps: float = 1e-6, + stochastic_sampling: bool = False, + ) ->
Array: + # Denoise the latents using the scheduler + denoised_latents = scheduler.step( + noise_pred, + t if current_timestep is None else current_timestep, + latents, + **extra_step_kwargs, + stochastic_sampling=stochastic_sampling, + ) + + if conditioning_mask is None: + return denoised_latents + + tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) + tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) + return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) + + def retrieve_timesteps( # currently doesn't support custom timesteps + self, + scheduler: FlaxRectifiedFlowMultistepScheduler, + latent_shape, + scheduler_state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + ): + scheduler_state = scheduler.set_timesteps( + state=scheduler_state, samples_shape=latent_shape, num_inference_steps=num_inference_steps + ) + timesteps = scheduler_state.timesteps + if ( + skip_initial_inference_steps < 0 + or skip_final_inference_steps < 0 + or skip_initial_inference_steps + skip_final_inference_steps + >= num_inference_steps # Use the original num_inference_steps here for the check + ): + raise ValueError( + "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" + ) + timesteps = timesteps[skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps] + scheduler_state = scheduler.set_timesteps(timesteps=timesteps, samples_shape=latent_shape, state=scheduler_state) + + return scheduler_state + + def encode_prompt( # currently only supports passing in a prompt + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + text_encoder_max_tokens: int = 256, + **kwargs, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + + max_length = text_encoder_max_tokens # TPU supports only lengths multiple of 128 + + assert self.text_encoder is not None, "You should provide either prompt_embeds or self.text_encoder should not be None," + prompt = self._text_preprocessing(prompt) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = jnp.array(text_inputs.input_ids) + + prompt_attention_mask = jnp.array(text_inputs.attention_mask) + prompt_embeds = self.text_encoder(text_input_ids, attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + prompt_embeds = jnp.reshape(prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) + prompt_attention_mask = jnp.tile(prompt_attention_mask, (1, num_images_per_prompt)) + prompt_attention_mask = jnp.reshape(prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens = self._text_preprocessing(negative_prompt) + uncond_tokens = uncond_tokens * batch_size + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + 
padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = jnp.array(uncond_input.attention_mask) + + negative_prompt_embeds = self.text_encoder( + jnp.array(uncond_input.input_ids), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = jnp.tile(negative_prompt_embeds, (1, num_images_per_prompt, 1)) + negative_prompt_embeds = jnp.reshape(negative_prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1)) + + negative_prompt_attention_mask = jnp.tile(negative_prompt_attention_mask, (1, num_images_per_prompt)) + negative_prompt_attention_mask = jnp.reshape(negative_prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) + + def prepare_latents( # currently no support for media item encoding, since encoder isn't tested + self, + latents: Optional[jnp.ndarray], + timestep: float, + latent_shape: Tuple[Any, ...], + dtype: jnp.dtype, + key: jax.random.PRNGKey, + ) -> jnp.ndarray: + """ + Prepares initial latents for a diffusion process, potentially encoding media items + or adding noise + """ + + if latents is not None: + assert latents.shape == latent_shape, f"Latents have to be of shape {latent_shape} but are {latents.shape}." + + # For backward compatibility, generate noise in the "patchified" shape and rearrange + b, c, f, h, w = latent_shape + + # Generate noise using jax.random.normal + noise_intermediate_shape = (b, f * h * w, c) + noise = jax.random.normal(key, noise_intermediate_shape, dtype=dtype) + + # Rearrange "b (f h w) c -> b c f h w" + # Step 1: Reshape to (b, f, h, w, c) + noise = noise.reshape(b, f, h, w, c) + # Step 2: Permute/Transpose to (b, c, f, h, w) + noise = jnp.transpose(noise, (0, 4, 1, 2, 3)) # Old (b,f,h,w,c) -> New (b,c,f,h,w) + + if latents is None: + latents = noise + else: + # Noise the latents to the required (first) timestep + timestep_array = jnp.array(timestep, dtype=dtype) + latents = timestep_array * noise + (1 - timestep_array) * latents + + return latents + + def prepare_conditioning( # no support for conditioning items, conditioning mask, needs to convert to torch before patchifier + self, + init_latents: jnp.ndarray, + ) -> Tuple[jnp.ndarray, jnp.ndarray, int]: + assert isinstance(self.vae, TorchaxCausalVideoAutoencoder) + init_latents = torch.from_numpy(np.array(init_latents)) + init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents) + init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=True) + return ( + jnp.array(init_latents.to(torch.float32).detach().numpy()), + jnp.array(init_pixel_coords.to(torch.float32).detach().numpy()), + 0, + ) + + def denormalize(self, images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: + r""" + Borrowed from diffusers.image_processor + Denormalize an image array to [0,1]. + + Args: + images (`np.ndarray` or `torch.Tensor`): + The image array to denormalize. + + Returns: + `np.ndarray` or `torch.Tensor`: + The denormalized image array. 
+ """ + return (images * 0.5 + 0.5).clamp(0, 1) + + def _denormalize_conditionally(self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None) -> torch.Tensor: + r""" + Borrowed from diffusers.image_processor + Denormalize a batch of images based on a condition list. + + Args: + images (`torch.Tensor`): + The input image tensor. + do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`): + A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the + value of `do_normalize` in the `VaeImageProcessor` config. + """ + if do_denormalize is None: + return self.denormalize(images) + + return torch.stack([self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])]) + + def postprocess_to_output_type(self, image, output_type): + """ + Borrowed from diffusers.image_processor + Currrently supporting output type latent and pt + """ + if not isinstance(image, torch.Tensor): + raise ValueError(f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor") + + if output_type not in ["latent", "pt", "np", "pil"]: + output_type = "np" + + if output_type == "latent": + return image + image = self._denormalize_conditionally(image, None) + if output_type == "pt": + return image + + def __call__( + self, + config, + height: int, + width: int, + num_frames: int, + negative_prompt: str = "", + num_images_per_prompt: Optional[int] = 1, + frame_rate: int = 30, + latents: Optional[jnp.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + guidance_timesteps: Optional[List[int]] = None, + decode_timestep: Union[List[float], float] = 0.05, + decode_noise_scale: Optional[List[float]] = 0.025, + enhance_prompt: bool = False, + text_encoder_max_tokens: int = 256, + num_inference_steps: int = 50, + guidance_scale: Union[float, List[float]] = 4.5, + rescaling_scale: Union[float, List[float]] = 0.7, + stg_scale: Union[float, List[float]] = 1.0, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + cfg_star_rescale: bool = False, + seed: int = 0, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, + **kwargs, + ): + key = jax.random.PRNGKey(seed) + prompt = self.config.prompt + is_video = kwargs.get("is_video", False) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + latent_num_frames = num_frames // self.video_scale_factor + if isinstance(self.vae, TorchaxCausalVideoAutoencoder) and is_video: + latent_num_frames += 1 + with open(config.config_path, "r") as f: + model_config = json.load(f) + latent_shape = ( + batch_size * num_images_per_prompt, + model_config["in_channels"], + latent_num_frames, + latent_height, + latent_width, + ) + scheduler_state = self.retrieve_timesteps( + self.scheduler, + latent_shape, + self.scheduler_state, + num_inference_steps, + None, + skip_initial_inference_steps, + skip_final_inference_steps, + ) + + # set up guidance + guidance_mapping = [] + + if guidance_timesteps: + for timestep in scheduler_state.timesteps: + indices = [i for i, val in enumerate(guidance_timesteps) if val <= timestep] + guidance_mapping.append(indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1)) + + if not 
isinstance(guidance_scale, list): + guidance_scale = [guidance_scale] * len(scheduler_state.timesteps) + else: + guidance_scale = [guidance_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(stg_scale, list): + stg_scale = [stg_scale] * len(scheduler_state.timesteps) + else: + stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(rescaling_scale, list): + rescaling_scale = [rescaling_scale] * len(scheduler_state.timesteps) + else: + rescaling_scale = [rescaling_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] + do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale) + do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale) + do_rescaling = any(x != 1.0 for x in rescaling_scale) + + num_conds = 1 + if do_classifier_free_guidance: + num_conds += 1 + if do_spatio_temporal_guidance: + num_conds += 1 + + # prepare skip block list + is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list) + + if not is_list_of_lists: + skip_block_list = [skip_block_list] * len(scheduler_state.timesteps) + else: + new_skip_block_list = [] + for i in range(len(scheduler_state.timesteps)): + new_skip_block_list.append(skip_block_list[guidance_mapping[i]]) + + skip_block_list = new_skip_block_list + + # Default to no layer skipping so skip_layer_masks is always defined when passed to run_inference below. + skip_layer_masks = None + if do_spatio_temporal_guidance: + if skip_block_list is not None: + skip_layer_masks = [ + self.transformer.create_skip_layer_mask(batch_size, num_conds, num_conds - 1, skip_blocks) + for skip_blocks in skip_block_list + ] + + if enhance_prompt: + prompt = generate_cinematic_prompt( + self.prompt_enhancer_image_caption_model, + self.prompt_enhancer_image_caption_processor, + self.prompt_enhancer_llm_model, + self.prompt_enhancer_llm_tokenizer, + prompt, + None, # conditioning items set to None + max_new_tokens=text_encoder_max_tokens, + ) + + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + text_encoder_max_tokens=text_encoder_max_tokens, + ) + prompt_embeds_batch = prompt_embeds + prompt_attention_mask_batch = prompt_attention_mask + + if do_classifier_free_guidance: + prompt_embeds_batch = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate([negative_prompt_attention_mask, prompt_attention_mask], axis=0) + if do_spatio_temporal_guidance: + prompt_embeds_batch = jnp.concatenate([prompt_embeds_batch, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate( + [ + prompt_attention_mask_batch, + prompt_attention_mask, + ], + axis=0, + ) + + # optionally pass in a latent here + latents = self.prepare_latents( + latents=latents, + timestep=scheduler_state.timesteps[0], + latent_shape=latent_shape, + dtype=jnp.float32, + key=key, + ) + + latents, pixel_coords, num_cond_latents = self.prepare_conditioning( + init_latents=latents, + ) + + pixel_coords = jnp.concatenate([pixel_coords] * num_conds, axis=0) + fractional_coords = pixel_coords.astype(jnp.float32) + fractional_coords = fractional_coords.at[:, 0].set(fractional_coords[:, 0] * (1.0 / frame_rate)) + validate_transformer_inputs(prompt_embeds_batch, fractional_coords, latents, prompt_attention_mask_batch) + + p_run_inference = functools.partial(
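+ # Bind every loop-invariant argument here; the denoising loop then only threads the mutable pieces (transformer state, latents, scheduler state) through each step. +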
run_inference, + transformer=self.transformer, + config=self.config, + mesh=self.mesh, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds_batch, + segment_ids=None, + encoder_attention_segment_ids=prompt_attention_mask_batch, + scheduler=self.scheduler, + do_classifier_free_guidance=do_classifier_free_guidance, + num_conds=num_conds, + guidance_scale=guidance_scale, + do_spatio_temporal_guidance=do_spatio_temporal_guidance, + stg_scale=stg_scale, + do_rescaling=do_rescaling, + rescaling_scale=rescaling_scale, + batch_size=batch_size, + skip_layer_masks=skip_layer_masks, + skip_layer_strategy=skip_layer_strategy, + cfg_star_rescale=cfg_star_rescale, + ) + + with self.mesh: + latents, scheduler_state = p_run_inference( + transformer_state=self.transformer_state, latents=latents, scheduler_state=scheduler_state + ) + + latents = latents[:, num_cond_latents:] + + latents = self.patchifier.unpatchify( + latents=latents, + output_height=latent_height, + output_width=latent_width, + out_channels=model_config["in_channels"] // math.prod(self.patchifier.patch_size), + ) + + if output_type != "latent": + if self.vae.decoder.timestep_conditioning: + noise = jax.random.normal(key, latents.shape, dtype=latents.dtype) + # Convert decode_timestep to a list if it's not already one + if not isinstance(decode_timestep, (list, jnp.ndarray)): + decode_timestep = [decode_timestep] * latents.shape[0] + + # Handle decode_noise_scale + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, (list, jnp.ndarray)): + decode_noise_scale = [decode_noise_scale] * latents.shape[0] + + decode_timestep = jnp.array(decode_timestep, dtype=jnp.float32) + + # Reshape decode_noise_scale for broadcasting + decode_noise_scale = jnp.array(decode_noise_scale, dtype=jnp.float32) + decode_noise_scale = jnp.reshape(decode_noise_scale, (latents.shape[0],) + (1,) * (latents.ndim - 1)) + + # Apply the noise and scale + latents = latents * (1 - decode_noise_scale) + noise * decode_noise_scale + else: + decode_timestep = None + image = self.vae.decode( + latents=jax.device_put(latents, jax.devices("tpu")[0]), + is_video=is_video, + vae_per_channel_normalize=kwargs.get("vae_per_channel_normalize", True), + timestep=decode_timestep, + ) + # convert back to torch to postprocess using the diffusers library + image = self.postprocess_to_output_type( + torch.from_numpy(np.asarray(image.astype(jnp.float16))), output_type=output_type + ) + + else: + image = latents + + if not return_dict: + return (image,) + + return image + + +def transformer_forward_pass( + latents, + state, + noise_cond, + transformer, + fractional_cords, + prompt_embeds, + segment_ids, + encoder_attention_segment_ids, + skip_layer_mask, + skip_layer_strategy, +): + + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + ) + return noise_pred, state + + +def run_inference( + transformer_state, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + scheduler, + segment_ids, + encoder_attention_segment_ids, + scheduler_state, + do_classifier_free_guidance, + num_conds, + guidance_scale, + do_spatio_temporal_guidance, + stg_scale, + do_rescaling, + rescaling_scale, + batch_size, + 
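# skip_layer_masks: per-timestep masks built by create_skip_layer_mask for spatio-temporal guidance; None when layer skipping is disabled. +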
skip_layer_masks, + skip_layer_strategy, + cfg_star_rescale, +): + for i, t in enumerate(scheduler_state.timesteps): + current_timestep = t + latent_model_input = jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents + if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): + if isinstance(current_timestep, float): + dtype = jnp.float32 + else: + dtype = jnp.int32 + + current_timestep = jnp.array( + [current_timestep], + dtype=dtype, + ) + elif current_timestep.ndim == 0: + current_timestep = jnp.expand_dims(current_timestep, axis=0) + + # Broadcast to batch dimension + current_timestep = jnp.broadcast_to(current_timestep, (latent_model_input.shape[0], 1)) + + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + noise_pred, transformer_state = transformer_forward_pass( + latent_model_input, + transformer_state, + current_timestep, + transformer, + fractional_cords, + prompt_embeds, + segment_ids, + encoder_attention_segment_ids, + skip_layer_mask=(skip_layer_masks[i] if skip_layer_masks is not None else None), + skip_layer_strategy=skip_layer_strategy, + ) + + # perform guidance on noise prediction + if do_spatio_temporal_guidance: + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_text = chunks[-2] + noise_pred_text_perturb = chunks[-1] + + if do_classifier_free_guidance: + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_uncond = chunks[0] + noise_pred_text = chunks[1] + if cfg_star_rescale: + positive_flat = noise_pred_text.reshape(batch_size, -1) + negative_flat = noise_pred_uncond.reshape(batch_size, -1) + dot_product = jnp.sum(positive_flat * negative_flat, axis=1, keepdims=True) + squared_norm = jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 + alpha = dot_product / squared_norm + alpha = alpha.reshape(batch_size, 1, 1) + noise_pred_uncond = alpha * noise_pred_uncond + noise_pred = noise_pred_uncond + guidance_scale[i] * (noise_pred_text - noise_pred_uncond) + elif do_spatio_temporal_guidance: + noise_pred = noise_pred_text + + if do_spatio_temporal_guidance: + noise_pred = noise_pred + stg_scale[i] * (noise_pred_text - noise_pred_text_perturb) + if do_rescaling and stg_scale[i] > 0.0: + noise_pred_text_std = jnp.std(noise_pred_text.reshape(batch_size, -1), axis=1, keepdims=True) + noise_pred_std = jnp.std(noise_pred.reshape(batch_size, -1), axis=1, keepdims=True) + + factor = noise_pred_text_std / noise_pred_std + factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) + + noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) + current_timestep = current_timestep[:1] + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() + + return latents, scheduler_state + + +def adain_filter_latent(latents: jnp.ndarray, reference_latents: jnp.ndarray, factor: float = 1.0) -> jnp.ndarray: + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on + statistics from a reference latent tensor, implemented in JAX. + + Args: + latents (jax.Array): Input latents to normalize. Expected shape (B, C, F, H, W). + reference_latents (jax.Array): The reference latents providing style statistics. + Expected shape (B, C, F, H, W). + factor (float): Blending factor between original and transformed latent. + Range: -10.0 to 10.0, Default: 1.0 + + Returns: + jax.Array: The transformed latent tensor. 
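+ + For each (batch, channel) slice x with reference slice r, the core transform is (x - mean(x)) / std(x) * std(r) + mean(r); the result is blended as latents * (1 - factor) + transformed * factor.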
+ """ + with default_env(): + latents = jax.device_put(latents, jax.devices("tpu")[0]) + reference_latents = jax.device_put(reference_latents, jax.devices("tpu")[0]) + + # Define the core AdaIN operation for a single (F, H, W) slice. + # This function will be vmapped over batch (B) and channel (C) dimensions. + def _adain_single_slice(latent_slice: jnp.ndarray, ref_latent_slice: jnp.ndarray) -> jnp.ndarray: + """ + Applies AdaIN to a single latent slice (F, H, W) based on a reference slice. + """ + r_mean = jnp.mean(ref_latent_slice) + r_sd = jnp.std(ref_latent_slice) + + # Calculate standard deviation and mean for the input latent slice + i_mean = jnp.mean(latent_slice) + i_sd = jnp.std(latent_slice) + i_sd_safe = jnp.where(i_sd < 1e-6, 1.0, i_sd) + normalized_latent = (latent_slice - i_mean) / i_sd_safe + transformed_latent_slice = normalized_latent * r_sd + r_mean + return transformed_latent_slice + + transformed_latents_core = jax.vmap( + jax.vmap(_adain_single_slice, in_axes=(0, 0)), in_axes=(0, 0) # Vmap over batch (axis 0) + )(latents, reference_latents) + result_blended = latents * (1.0 - factor) + transformed_latents_core * factor + + return result_blended + + +class LTXMultiScalePipeline: + + @classmethod + def load_latent_upsampler(cls, config): + spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path + + if spatial_upscaler_model_name_or_path and not os.path.isfile(spatial_upscaler_model_name_or_path): + spatial_upscaler_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=spatial_upscaler_model_name_or_path, + local_dir=config.output_dir, + repo_type="model", + ) + else: + spatial_upscaler_model_path = spatial_upscaler_model_name_or_path + if not config.spatial_upscaler_model_path: + raise ValueError( + "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" + ) + latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path) + latent_upsampler.eval() + return latent_upsampler + + def _upsample_latents(self, latest_upsampler: LatentUpsampler, latents: jnp.ndarray): + latents = jax.device_put(latents, jax.devices("tpu")[0]) + with default_env(): + latents = un_normalize_latents(interop.torch_view(latents), self.vae, vae_per_channel_normalize=True) + upsampled_latents = latest_upsampler(torch.from_numpy(np.array(latents))) # converted back to torch before upsampler + upsampled_latents = normalize_latents( + interop.torch_view(jnp.array(upsampled_latents.detach().numpy())), self.vae, vae_per_channel_normalize=True + ) + return upsampled_latents + + def __init__(self, video_pipeline: LTXVideoPipeline): + self.video_pipeline = video_pipeline + self.vae = video_pipeline.vae + + def __call__( + self, height, width, num_frames, is_video, output_type, config, seed: int = 0, enhance_prompt: bool = False + ) -> Any: + # first pass + original_output_type = output_type + output_type = "latent" + result = self.video_pipeline( + config=config, + height=height, + width=width, + enhance_prompt=enhance_prompt, + num_frames=num_frames, + is_video=is_video, + output_type=output_type, + seed=seed, + guidance_scale=config.first_pass["guidance_scale"], + stg_scale=config.first_pass["stg_scale"], + rescaling_scale=config.first_pass["rescaling_scale"], + skip_initial_inference_steps=config.first_pass["skip_initial_inference_steps"], + skip_final_inference_steps=config.first_pass["skip_final_inference_steps"], + num_inference_steps=config.first_pass["num_inference_steps"], + 
guidance_timesteps=config.first_pass["guidance_timesteps"], + cfg_star_rescale=config.first_pass["cfg_star_rescale"], + skip_layer_strategy=None, + skip_block_list=config.first_pass["skip_block_list"], + ) + latents = result + max_logging.log("first pass done") + latent_upsampler = self.load_latent_upsampler(config) + upsampled_latents = self._upsample_latents(latent_upsampler, latents) + upsampled_latents = adain_filter_latent(latents=upsampled_latents, reference_latents=latents) + + # second pass + latents = upsampled_latents + result = self.video_pipeline( + config=config, + height=height * 2, + width=width * 2, + enhance_prompt=enhance_prompt, + num_frames=num_frames, + is_video=is_video, + seed=seed, + output_type=original_output_type, + latents=latents, + guidance_scale=config.second_pass["guidance_scale"], + stg_scale=config.second_pass["stg_scale"], + rescaling_scale=config.second_pass["rescaling_scale"], + skip_initial_inference_steps=config.second_pass["skip_initial_inference_steps"], + skip_final_inference_steps=config.second_pass["skip_final_inference_steps"], + num_inference_steps=config.second_pass["num_inference_steps"], + guidance_timesteps=config.second_pass["guidance_timesteps"], + cfg_star_rescale=config.second_pass["cfg_star_rescale"], + skip_layer_strategy=None, + skip_block_list=config.second_pass["skip_block_list"], + ) + + if original_output_type != "latent": + num_frames = result.shape[2] + videos = rearrange(result, "b c f h w -> (b f) c h w") + + videos = F.interpolate( + videos, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) + result = videos + + return result diff --git a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py new file mode 100644 index 000000000..c4b2657e3 --- /dev/null +++ b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py @@ -0,0 +1,330 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
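+# +# A rectified-flow scheduler treats the noisy sample as x_t = t * noise + (1 - t) * x_0, so the model +# output is a velocity and the deterministic branch of step() below is a single Euler update: +# +# dt = t_current - t_next +# x_next = x_current - dt * velocity +# +# Timestep grids may be uniform, linear-quadratic, or constant-shifted, and can additionally be +# resolution-shifted (SD3 or SimpleDiffusion) before sampling.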
+ + +from typing import List, Optional, Tuple, Union +from dataclasses import dataclass +from pathlib import Path +import os +from safetensors import safe_open + +import flax +import jax +import jax.numpy as jnp +import json +from maxdiffusion.configuration_utils import ConfigMixin, register_to_config +from maxdiffusion.schedulers.scheduling_utils_flax import ( + CommonSchedulerState, + FlaxSchedulerMixin, + FlaxSchedulerOutput, +) + + +def linear_quadratic_schedule_jax( + num_steps: int, threshold_noise: float = 0.025, linear_steps: Optional[int] = None +) -> jnp.ndarray: + if num_steps == 1: + return jnp.array([1.0], dtype=jnp.float32) + if linear_steps is None: + linear_steps = num_steps // 2 + + linear_sigma_schedule = jnp.arange(linear_steps) * threshold_noise / linear_steps + + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_indices = jnp.arange(linear_steps, num_steps) + quadratic_sigma_schedule = quadratic_coef * (quadratic_indices**2) + linear_coef * quadratic_indices + const + + sigma_schedule = jnp.concatenate([linear_sigma_schedule, quadratic_sigma_schedule]) + sigma_schedule = jnp.concatenate([sigma_schedule, jnp.array([1.0])]) + sigma_schedule = 1.0 - sigma_schedule + return sigma_schedule[:-1].astype(jnp.float32) + + +def time_shift_jax(mu: float, sigma: float, t: jnp.ndarray) -> jnp.ndarray: + mu_f = jnp.array(mu, dtype=jnp.float32) + sigma_f = jnp.array(sigma, dtype=jnp.float32) + return jnp.exp(mu_f) / (jnp.exp(mu_f) + (1 / t - 1) ** sigma_f) + + +def _prod_jax(iterable): + return jnp.prod(jnp.array(iterable, dtype=jnp.float32)) + + +def get_normal_shift_jax( + n_tokens: int, + min_tokens: int = 1024, + max_tokens: int = 4096, + min_shift: float = 0.95, + max_shift: float = 2.05, +) -> float: + m = (max_shift - min_shift) / (max_tokens - min_tokens) + b = min_shift - m * min_tokens + return m * n_tokens + b + + +def append_dims_jax(x: jnp.ndarray, target_dims: int) -> jnp.ndarray: + """Appends singleton dimensions to the end of a tensor until it reaches `target_dims`.""" + return x[(...,) + (None,) * (target_dims - x.ndim)] + + +def strech_shifts_to_terminal_jax(shifts: jnp.ndarray, terminal: float = 0.1) -> jnp.ndarray: + if shifts.size == 0: + raise ValueError("The 'shifts' tensor must not be empty.") + if terminal <= 0 or terminal >= 1: + raise ValueError("The terminal value must be between 0 and 1 (exclusive).") + + one_minus_z = 1.0 - shifts + # Using shifts[-1] for the last element + scale_factor = one_minus_z[-1] / (1.0 - terminal) + stretched_shifts = 1.0 - (one_minus_z / scale_factor) + + return stretched_shifts + + +def sd3_resolution_dependent_timestep_shift_jax( + samples_shape: Tuple[int, ...], + timesteps: jnp.ndarray, + target_shift_terminal: Optional[float] = None, +) -> jnp.ndarray: + if len(samples_shape) == 3: + _, m, _ = samples_shape + elif len(samples_shape) in [4, 5]: + m = _prod_jax(samples_shape[2:]) + else: + raise ValueError("Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)") + + shift = get_normal_shift_jax(int(m)) + time_shifts = time_shift_jax(shift, 1.0, timesteps) + + if target_shift_terminal is not None: + time_shifts = strech_shifts_to_terminal_jax(time_shifts, target_shift_terminal) + return time_shifts + + +def 
simple_diffusion_resolution_dependent_timestep_shift_jax( + samples_shape: Tuple[int, ...], + timesteps: jnp.ndarray, + n: int = 32 * 32, +) -> jnp.ndarray: + if len(samples_shape) == 3: + _, m, _ = samples_shape + elif len(samples_shape) in [4, 5]: + m = _prod_jax(samples_shape[2:]) + else: + raise ValueError("Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)") + # Ensure m and n are float32 for calculations + m_f = jnp.array(m, dtype=jnp.float32) + n_f = jnp.array(n, dtype=jnp.float32) + + snr = (timesteps / (1 - timesteps)) ** 2 # SNR of the rectified-flow path at time t + shift_snr = jnp.log(snr) + 2 * jnp.log(m_f / n_f) # scale the log-SNR by the resolution ratio m / n + shifted_timesteps = jax.nn.sigmoid(0.5 * shift_snr) + + return shifted_timesteps + + +@flax.struct.dataclass +class RectifiedFlowSchedulerState: + """ + Data class to hold the mutable state of the RectifiedFlowScheduler. + """ + + common: CommonSchedulerState + init_noise_sigma: float + num_inference_steps: Optional[int] = None + timesteps: Optional[jnp.ndarray] = None + sigmas: Optional[jnp.ndarray] = None + + @classmethod + def create(cls, common_state: CommonSchedulerState, init_noise_sigma: float): + return cls( + common=common_state, + init_noise_sigma=init_noise_sigma, + num_inference_steps=None, + timesteps=None, + sigmas=None, + ) + + +@dataclass +class FlaxRectifiedFlowSchedulerOutput(FlaxSchedulerOutput): + state: RectifiedFlowSchedulerState + + +class FlaxRectifiedFlowMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + Note: shifting and stochastic sampling not tested + """ + + dtype: jnp.dtype + order = 1 + + @property + def has_state(self) -> bool: + return True + + @register_to_config + def __init__( + self, + num_train_timesteps=1000, + trained_betas: Optional[Union[jnp.ndarray, List[float]]] = None, + beta_schedule: str = "linear", + rescale_zero_terminal_snr: bool = False, + beta_start: float = 0.0001, + beta_end: float = 0.02, + shifting: Optional[str] = None, + base_resolution: int = 32**2, + target_shift_terminal: Optional[float] = None, + sampler: Optional[str] = "Uniform", + shift: Optional[float] = None, + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> RectifiedFlowSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + init_noise_sigma = 1.0 + return RectifiedFlowSchedulerState.create(common_state=common, init_noise_sigma=init_noise_sigma) + + def get_initial_timesteps_jax(self, num_timesteps: int, shift: Optional[float] = None) -> jnp.ndarray: + if self.config.sampler == "Uniform": + return jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype) + elif self.config.sampler == "LinearQuadratic": + return linear_quadratic_schedule_jax(num_timesteps).astype(self.dtype) + elif self.config.sampler == "Constant": + assert shift is not None, "Shift must be provided for constant time shift sampler."
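+ # A constant shift warps the uniform grid t -> exp(shift) / (exp(shift) + (1 / t - 1)) via time_shift_jax.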
+ return time_shift_jax(shift, 1.0, jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype)).astype( + self.dtype + ) + else: + raise ValueError(f"Sampler {self.config.sampler} is not supported.") + + def shift_timesteps_jax(self, samples_shape: Tuple[int, ...], timesteps: jnp.ndarray) -> jnp.ndarray: + if self.config.shifting == "SD3": + return sd3_resolution_dependent_timestep_shift_jax(samples_shape, timesteps, self.config.target_shift_terminal) + elif self.config.shifting == "SimpleDiffusion": + return simple_diffusion_resolution_dependent_timestep_shift_jax(samples_shape, timesteps, self.config.base_resolution) + return timesteps + + @staticmethod + def from_pretrained_jax(pretrained_model_path: Union[str, os.PathLike]): + pretrained_model_path = Path(pretrained_model_path) + config = None + if pretrained_model_path.is_file(): + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + configs = json.loads(metadata["config"]) + config = configs["scheduler"] + + elif pretrained_model_path.is_dir(): + diffusers_noise_scheduler_config_path = pretrained_model_path / "scheduler" / "scheduler_config.json" + + if not diffusers_noise_scheduler_config_path.is_file(): + raise FileNotFoundError(f"Scheduler config not found at {diffusers_noise_scheduler_config_path}") + + with open(diffusers_noise_scheduler_config_path, "r") as f: + scheduler_config = json.load(f) + config = scheduler_config + return FlaxRectifiedFlowMultistepScheduler.from_config(config) + + def set_timesteps( + self, + state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + samples_shape: Optional[Tuple[int, ...]] = None, + timesteps: Optional[jnp.ndarray] = None, + device: Optional[str] = None, + ) -> RectifiedFlowSchedulerState: + if timesteps is not None and num_inference_steps is not None: + raise ValueError("You cannot provide both `timesteps` and `num_inference_steps`.") + + # Determine the number of inference steps if not provided + if num_inference_steps is None and timesteps is None: + raise ValueError("Either `num_inference_steps` or `timesteps` must be provided.") + + if timesteps is None: + # Keep this a Python int so it can size the schedule arrays. + num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) + timesteps = self.get_initial_timesteps_jax(num_inference_steps, shift=self.config.shift).astype(self.dtype) + + # Apply shifting if samples_shape is provided and shifting is configured + if samples_shape is not None: + timesteps = self.shift_timesteps_jax(samples_shape, timesteps) + else: + timesteps = jnp.asarray(timesteps, dtype=self.dtype) + num_inference_steps = len(timesteps) + + return state.replace( + timesteps=timesteps, + num_inference_steps=num_inference_steps, + sigmas=timesteps, # sigmas are the same as timesteps in RF + ) + + def scale_model_input( + self, state: RectifiedFlowSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + # The Rectified Flow scheduler doesn't scale the model input, so the sample is returned as is.
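+ # The no-op keeps the call signature compatible with the other Flax schedulers in maxdiffusion.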
+ return sample + + def step( + self, + state: RectifiedFlowSchedulerState, + model_output: jnp.ndarray, + timestep: jnp.ndarray, + sample: jnp.ndarray, + return_dict: bool = True, + stochastic_sampling: bool = False, + generator: Optional[jax.random.PRNGKey] = None, + ) -> Union[FlaxRectifiedFlowSchedulerOutput, Tuple[jnp.ndarray, RectifiedFlowSchedulerState]]: + if state.num_inference_steps is None: + raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler") + + timesteps_padded = jnp.concatenate([state.timesteps, jnp.array([0.0], dtype=self.dtype)]) + + if timestep.ndim == 0: + current_t_idx = jnp.where(state.timesteps == timestep, size=1, fill_value=len(state.timesteps))[0][0] + lower_timestep = jnp.where(current_t_idx + 1 < len(timesteps_padded), timesteps_padded[current_t_idx + 1], 0.0) + dt = timestep - lower_timestep + else: + current_t_indices = jnp.searchsorted(state.timesteps, timestep, side="right") # timesteps is decreasing + current_t_indices = jnp.where(current_t_indices > 0, current_t_indices - 1, 0) # adjust for right side search + lower_timestep_indices = jnp.minimum(current_t_indices + 1, len(timesteps_padded) - 1) + lower_timestep = timesteps_padded[lower_timestep_indices] + dt = timestep - lower_timestep + dt = append_dims_jax(dt, sample.ndim) + + # Compute previous sample + if stochastic_sampling: + if generator is None: + raise ValueError("`generator` PRNGKey must be provided for stochastic sampling.") + broadcastable_timestep = append_dims_jax(timestep, sample.ndim) + + x0 = sample - broadcastable_timestep * model_output + next_timestep = lower_timestep # timestep - dt lands exactly on the next (lower) timestep + + noise = jax.random.normal(generator, sample.shape, dtype=self.dtype) + prev_sample = self.add_noise(state.common, x0, noise, next_timestep) + else: + prev_sample = sample - dt * model_output + + if not return_dict: + return (prev_sample, state) + + return FlaxRectifiedFlowSchedulerOutput(prev_sample=prev_sample, state=state) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py new file mode 100644 index 000000000..9398c9156 --- /dev/null +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -0,0 +1,202 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.
+ """ + +import os +import torch +import jax +import numpy as np +import jax.numpy as jnp +import unittest +from absl.testing import absltest +from jax.sharding import Mesh +import json +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import functools +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import ( + create_device_mesh, + setup_initial_state, + get_memory_allocations, +) +from jax.sharding import PartitionSpec as P +import orbax.checkpoint as ocp + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def load_ref_prediction(): + base_dir = os.path.dirname(__file__) + saved_prediction_path = os.path.join(base_dir, "ltx_vid_transformer_test_ref_pred") + predict_dict = torch.load(saved_prediction_path) + noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) + return noise_pred_pt + + +def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): + latents, state, noise_cond = args + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + return noise_pred, state, noise_cond + + +def run_inference( + states, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + segment_ids, + encoder_attention_segment_ids, +): + transformer_state = states["transformer"] + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + fractional_cords=fractional_cords, + prompt_embeds=prompt_embeds, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) + return latents + + +class LTXTransformerTest(unittest.TestCase): + + def test_one_step_transformer(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"), + ], + unittest=True, + ) + config = pyconfig.config + noise_pred_pt = load_ref_prediction() + + # set up transformer + key = jax.random.PRNGKey(42) + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + base_dir = os.path.dirname(__file__) + config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") + + with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True + ) + + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + 
weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + states["transformer"] = transformer_state + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "latents": (batch_size, num_tokens, in_channels), + "fractional_coords": (batch_size, 3, num_tokens), + "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(example_inputs["latents"], data_sharding) + prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + config=config, + mesh=mesh, + latents=latents, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ), + in_shardings=(state_shardings,), + out_shardings=None, + ) + + noise_pred = p_run_inference(states).block_until_ready() + noise_pred = torch.from_numpy(np.array(noise_pred)) + + torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred new file mode 100644 index 000000000..0a9fe9120 Binary files /dev/null and b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred differ