diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 728d2f2e3..05f332fb7 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -50,7 +50,7 @@ jobs: ruff check . - name: PyTest run: | - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 7415ed682..42e50d775 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -374,6 +374,7 @@ _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"] _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] + _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) _import_structure["schedulers"].extend( [ @@ -453,6 +454,7 @@ from .models.modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .models.ltx_video.transformers.transformer3d import Transformer3DModel from .models.vae_flax import FlaxAutoencoderKL from .pipelines import FlaxDiffusionPipeline from .schedulers import ( diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index dd78eaa6c..b83e85a87 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -213,8 +213,11 @@ def load_state_if_possible( max_logging.log(f"restoring from this run's directory latest step {latest_step}") try: if not enable_single_replica_ckpt_restoring: - item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} - return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) + if checkpoint_item == " ": + return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state)) + else: + item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} + return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) def map_to_pspec(data): pspec = data.sharding.spec diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml new file mode 100644 index 000000000..0fdbe7f9f --- /dev/null +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -0,0 +1,101 @@ +#hardware +hardware: 'tpu' +skip_jax_distributed_system: False + +jax_cache_dir: '' +weights_dtype: 'bfloat16' +activations_dtype: 'bfloat16' + + +run_name: '' +output_dir: 'ltx-video-output' +save_config_to_gcs: False +#Checkpoints +text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax" +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +frame_rate: 30 +max_sequence_length: 512 +sampler: "from_checkpoint" + + + + + +# Generation parameters +pipeline_type: multi-scale +prompt: "A man in a dimly lit room talks on a 
vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie." +height: 512 +width: 512 +num_frames: 88 #344 +flow_shift: 5.0 +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors" +prompt_enhancement_words_threshold: 120 +stg_mode: "attention_values" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +models_dir: "/mnt/disks/diffusionproj" #where safetensor file is + + +first_pass: + guidance_scale: [1, 1, 6, 8, 6, 1, 1] + stg_scale: [0, 0, 4, 4, 4, 2, 1] + rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] + guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] + skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] + num_inference_steps: 30 + skip_final_inference_steps: 3 + skip_initial_inference_steps: 0 + cfg_star_rescale: True + +second_pass: + guidance_scale: [1] + stg_scale: [1] + rescaling_scale: [1] + guidance_timesteps: [1.0] + skip_block_list: [27] + num_inference_steps: 30 + skip_initial_inference_steps: 17 + skip_final_inference_steps: 0 + cfg_star_rescale: True + +#parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] +logical_axis_rules: [ + ['batch', 'data'], + ['activation_heads', 'fsdp'], + ['activation_batch', ['data','fsdp']], + ['activation_kv', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['norm', 'fsdp'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ['conv_in', 'fsdp'] + ] +data_sharding: [['data', 'fsdp', 'tensor']] +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + + + +learning_rate_schedule_steps: -1 +max_train_steps: 500 #TODO: change this +pretrained_model_name_or_path: '' +unet_checkpoint: '' +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +per_device_batch_size: 1 +compile_topology_num_slices: -1 +quantization_local_shard_count: -1 +jit_initializers: True +enable_single_replica_ckpt_restoring: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py new file mode 100644 index 000000000..f7d7e6d03 --- /dev/null +++ b/src/maxdiffusion/generate_ltx_video.py @@ -0,0 +1,146 @@ +import numpy as np +from absl import app +from typing import Sequence +from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline +from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline +from maxdiffusion import pyconfig +import imageio +from datetime import datetime +import os +import torch +from pathlib import Path + + +def calculate_padding( + source_height: int, source_width: int, target_height: int, target_width: int +) -> tuple[int, int, int, int]: + + # Calculate total padding needed + pad_height = target_height - 
source_height + pad_width = target_width - source_width + + # Calculate padding for each side + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top # Handles odd padding + pad_left = pad_width // 2 + pad_right = pad_width - pad_left # Handles odd padding + + # Return padded tensor + # Padding format is (left, right, top, bottom) + padding = (pad_left, pad_right, pad_top, pad_bottom) + return padding + + +def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: + # Remove non-letters and convert to lowercase + clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace()) + + # Split into words + words = clean_text.split() + + # Build result string keeping track of length + result = [] + current_length = 0 + + for word in words: + # Add word length plus 1 for underscore (except for first word) + new_length = current_length + len(word) + + if new_length <= max_len: + result.append(word) + current_length += len(word) + else: + break + + return "-".join(result) + + +def get_unique_filename( + base: str, + ext: str, + prompt: str, + seed: int, + resolution: tuple[int, int, int], + dir: Path, + endswith=None, + index_range=1000, +) -> Path: + base_filename = ( + f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}" + ) + for i in range(index_range): + filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}" + if not os.path.exists(filename): + return filename + raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.") + + +def run(config): + height_padded = ((config.height - 1) // 32 + 1) * 32 + width_padded = ((config.width - 1) // 32 + 1) * 32 + num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 + padding = calculate_padding(config.height, config.width, height_padded, width_padded) + + seed = 10 + generator = torch.Generator().manual_seed(seed) + pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=False) + pipeline = LTXMultiScalePipeline(pipeline) + images = pipeline( + height=height_padded, + width=width_padded, + num_frames=num_frames_padded, + output_type="pt", + generator=generator, + config=config, + ) + (pad_left, pad_right, pad_top, pad_bottom) = padding + pad_bottom = -pad_bottom + pad_right = -pad_right + if pad_bottom == 0: + pad_bottom = images.shape[3] + if pad_right == 0: + pad_right = images.shape[4] + images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right] + output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}") + output_dir.mkdir(parents=True, exist_ok=True) + for i in range(images.shape[0]): + # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C + video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy() + # Unnormalizing images to [0, 255] range + video_np = (video_np * 255).astype(np.uint8) + fps = config.frame_rate + height, width = video_np.shape[1:3] + # In case a single image is generated + if video_np.shape[0] == 1: + output_filename = get_unique_filename( + f"image_output_{i}", + ".png", + prompt=config.prompt, + seed=seed, + resolution=(height, width, config.num_frames), + dir=output_dir, + ) + imageio.imwrite(output_filename, video_np[0]) + else: + output_filename = get_unique_filename( + f"video_output_{i}", + ".mp4", + prompt=config.prompt, + seed=seed, + resolution=(height, width, config.num_frames), + dir=output_dir, + ) + print(output_filename) + # Write video + with imageio.get_writer(output_filename, 
fps=fps) as video: + for frame in video_np: + video.append_data(frame) + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f97..e645ecec1 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -402,7 +402,10 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 95861e24e..96a6f1286 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -13,9 +13,7 @@ # limitations under the License. from typing import TYPE_CHECKING - -from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available - +from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available _import_structure = {} @@ -32,6 +30,7 @@ from .vae_flax import FlaxAutoencoderKL from .lora import * from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .ltx_video.transformers.transformer3d import Transformer3DModel else: import sys diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py new file mode 100644 index 000000000..7e4185f36 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py new file mode 100644 index 000000000..ee7221652 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -0,0 +1,86 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from enum import Enum, auto +from typing import Optional + +import jax +from flax import linen as nn + +SKIP_GRADIENT_CHECKPOINT_KEY = "skip" + + +class GradientCheckpointType(Enum): + """ + Defines the type of the gradient checkpoint we will have + + NONE - means no gradient checkpoint + FULL - means full gradient checkpoint, wherever possible (minimum memory usage) + MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, + except for ones that involve batch dimension - that means that all attention and projection + layers will have gradient checkpoint, but not the backward with respect to the parameters + """ + + NONE = auto() + FULL = auto() + MATMUL_WITHOUT_BATCH = auto() + + @classmethod + def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": + """ + Constructs the gradient checkpoint type from a string + + Args: + s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. + + Returns: + GradientCheckpointType: The policy that corresponds to the string + """ + if s is None: + s = "none" + return GradientCheckpointType[s.upper()] + + def to_jax_policy(self): + """ + Converts the gradient checkpoint type to a jax policy + """ + match self: + case GradientCheckpointType.NONE: + return SKIP_GRADIENT_CHECKPOINT_KEY + case GradientCheckpointType.FULL: + return None + case GradientCheckpointType.MATMUL_WITHOUT_BATCH: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + def apply(self, module: nn.Module) -> nn.Module: + """ + Applies a gradient checkpoint policy to a module + if no policy is needed, it will return the module as is + + Args: + module (nn.Module): the module to apply the policy to + + Returns: + nn.Module: the module with the policy applied + """ + policy = self.to_jax_policy() + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module + return nn.remat( # pylint: disable=invalid-name + module, + prevent_cse=False, + policy=policy, + ) diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py new file mode 100644 index 000000000..3503ab3b4 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -0,0 +1,125 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import Union, Iterable, Tuple, Optional, Callable + +import numpy as np +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + + +Shape = Tuple[int, ...] +Initializer = Callable[[jax.random.PRNGKey, Shape, jax.numpy.dtype], jax.Array] +InitializerAxis = Union[int, Shape] + + +def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: + # A tuple by convention. 
len(axes_tuple) then also gives the rank efficiently. + return tuple(ax if ax >= 0 else ndim + ax for ax in axes) + + +def _canonicalize_tuple(x): + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +NdInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +KernelInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] + + +class DenseGeneral(nn.Module): + """A linear transformation with flexible axes. + + Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + use_bias: whether to add bias in linear transformation. + bias_norm: whether to add normalization before adding bias. + quant: quantization config, defaults to None implying no quantization. + """ + + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + weight_dtype: jnp.dtype = jnp.float32 + dtype: np.dtype = jnp.float32 + kernel_init: KernelInitializer = lecun_normal() + kernel_axes: Tuple[Optional[str], ...] = () + use_bias: bool = False + matmul_precision: str = "default" + + bias_init: Initializer = jax.nn.initializers.constant(0.0) + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. + """ + + def compute_dot_general(inputs, kernel, axis, contract_ind): + """Computes a dot_general operation that may be quantized.""" + dot_general = jax.lax.dot_general + matmul_precision = jax.lax.Precision(self.matmul_precision) + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) + + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + # kernel_in_axis = np.arange(len(axis)) + # kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) + kernel = self.param( + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + ) + kernel = jnp.asarray(kernel, self.dtype) + + contract_ind = tuple(range(0, len(axis))) + output = compute_dot_general(inputs, kernel, axis, contract_ind) + + if self.use_bias: + bias_axes, bias_shape = ( + self.kernel_axes[-len(features) :], + kernel_shape[-len(features) :], + ) + bias = self.param( + "bias", + nn.with_logical_partitioning(self.bias_init, bias_axes), + bias_shape, + self.weight_dtype, + ) + bias = jnp.asarray(bias, self.dtype) + + output += bias + return output diff --git a/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py new file mode 100644 index 000000000..cb4a6b9ce --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Lightricks Ltd. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py new file mode 100644 index 000000000..31c6b5b15 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -0,0 +1,116 @@ +from dataclasses import field +from typing import Any, Callable, Dict, List, Tuple, Optional + +import jax +from flax import linen as nn +import jax.numpy as jnp +from flax.linen import partitioning + + +class RepeatableCarryBlock(nn.Module): + """ + Integrates an input module in a jax carry format + + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ + + module: Callable[[Any], nn.Module] + module_init_args: List[Any] + module_init_kwargs: Dict[str, Any] + + @nn.compact + def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tuple[jax.Array, jax.Array], None]: + data_input, index_input = carry + + mod = self.module(*self.module_init_args, **self.module_init_kwargs) + + output_data = mod(index_input, data_input, *block_args) # Pass index_input to facilitate skip layers + + next_index = index_input + 1 + new_carry = (output_data, next_index) + + return new_carry, None + + +class RepeatableLayer(nn.Module): + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ + + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ + + module: Callable[[Any], nn.Module] + """ + A Callable function for single block construction + """ + + num_layers: int + """ + The amount of blocks to build + """ + + module_init_args: List[Any] = field(default_factory=list) + """ + args passed to RepeatableLayer.module callable, to support block construction + """ + + module_init_kwargs: Dict[str, Any] = field(default_factory=dict) + """ + kwargs passed to RepeatableLayer.module callable, to support block construction + """ + + pspec_name: Optional[str] = None + """ + Partition spec metadata + """ + + param_scan_axis: int = 0 + """ + The axis that the "layers" will be aggragated on + eg: if a kernel is shaped (8, 16) + N layers will be (N, 8, 16) if param_scan_axis=0 + and (8, N, 16) if param_scan_axis=1 + """ + + @nn.compact + def __call__(self, *args): + if not args: + raise ValueError("RepeatableLayer expects at least one argument for initial data input.") + + initial_data_input = args[0] + static_block_args = args[1:] + + initial_index = jnp.array(0, dtype=jnp.int32) # index of current transformer block + + scan_kwargs = {} + if self.pspec_name is not None: + scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name} + + initializing = self.is_mutable_collection("params") + params_spec = 
self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) + + in_axes_for_scan = (nn.broadcast,) * (len(args) - 1) + + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={"params": True}, + in_axes=in_axes_for_scan, + length=self.num_layers, + **scan_kwargs, + ) + + wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) + + # Call wrapped_function with the initial carry tuple and the static_block_args + (final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args) + + return final_data diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py new file mode 100644 index 000000000..7e4185f36 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py new file mode 100644 index 000000000..4ae1d9a00 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -0,0 +1,189 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, KernelInitializer + + +ACTIVATION_FUNCTIONS = { + "swish": jax.nn.silu, + "silu": jax.nn.silu, + # Mish is not in JAX by default + "mish": lambda x: x * jax.nn.tanh(jax.nn.softplus(x)), + "gelu": jax.nn.gelu, + "relu": jax.nn.relu, +} + + +@jax.jit +def approximate_gelu(x: jax.Array) -> jax.Array: + """ + Computes Gaussian Error Linear Unit (GELU) activation function + + Args: + x (jax.Array): The input tensor + + jax.Array: The output tensor + """ + # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs + # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior + if x.dtype in (jax.numpy.float64,): + x = x.clip(-10, None) + return jax.nn.gelu(x, approximate=True) + + +def get_activation(act_fn: str): + """Returns the activation function from string.""" + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + approximate: str = "none" + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def gelu(self, gate: jax.Array) -> jax.Array: + approximate_to_tanh = self.approximate == "tanh" + if approximate_to_tanh: + return approximate_gelu(gate) + else: + return jax.nn.gelu(gate, approximate=False) + + @nn.compact + def __call__(self, hidden_states): + if self.approximate not in ("none", "tanh"): + raise ValueError(f"approximate must be 'none' or 'tanh', got {self.approximate}") + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + hidden_states = proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] 
= () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states, *args, **kwargs): + # if len(args) > 0 or kwargs.get("scale", None) is not None: + # deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + # deprecate("scale", "1.0.0", deprecation_message) + + proj = DenseGeneral( + features=self.dim_out * 2, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + + hidden_states = proj(hidden_states) + hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) + return hidden_states * jax.nn.gelu(gate, approximate=False) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, x): + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + x = proj(x) + return x * jax.nn.sigmoid(1.702 * x) diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py new file mode 100644 index 000000000..e9b287649 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -0,0 +1,209 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import Dict, Optional, Tuple + +import jax +import jax.nn +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.transformers.activations import get_activation +from maxdiffusion.models.ltx_video.linear import DenseGeneral + + +def get_timestep_embedding_multidim( + timesteps: jnp.ndarray, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> jnp.ndarray: + """ + Computes sinusoidal timestep embeddings while preserving the original dimensions. + No reshaping to 1D is performed at any stage. + + Args: + timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. + embedding_dim (int): The dimension of the output. + flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) + or `sin, cos` (if False). + downscale_freq_shift (float): Controls the delta between frequencies between dimensions. + scale (float): Scaling factor applied to the embeddings. + max_period (int): Controls the maximum frequency of the embeddings. + + Returns: + jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. + """ + half_dim = embedding_dim // 2 + exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) + emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape + emb = nn.with_logical_constraint(emb, ("activation_batch", "activation_norm_length", "activation_embed")) + # Broadcasting to match shape (*timesteps.shape, half_dim) + emb = timesteps[..., None] * emb + emb = scale * emb + # Shape (*timesteps.shape, embedding_dim) + emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) + if flip_sin_to_cos: + emb = jnp.concatenate([emb[..., half_dim:], emb[..., :half_dim]], axis=-1) + + return emb + + +class TimestepEmbedding(nn.Module): + in_channels: int + time_embed_dim: int + act_fn: str = "silu" + out_dim: Optional[int] = None + sample_proj_bias: bool = True + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers efficiently""" + self.linear_1 = DenseGeneral( + self.time_embed_dim, + use_bias=self.sample_proj_bias, + kernel_axes=(None, "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + ) + + self.act = get_activation(self.act_fn) + time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim + self.linear_2 = DenseGeneral( + time_embed_dim_out, + use_bias=self.sample_proj_bias, + kernel_axes=("embed", "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + ) + + def __call__(self, sample, condition=None): + sample = nn.with_logical_constraint(sample, ("activation_batch", "activation_norm_length", "activation_embed")) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + num_channels: int + flip_sin_to_cos: bool + downscale_freq_shift: float + scale: int = 1 + + def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: + t_emb = 
get_timestep_embedding_multidim( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class AlphaCombinedTimestepSizeEmbeddings(nn.Module): + """ """ + + embedding_dim: int + size_emb_dim: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize sub-modules.""" + self.outdim = self.size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.embedding_dim, + name="timestep_embedder", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype)) + return timesteps_emb + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + embedding_dim: int + embedding_coefficient: int = 6 + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + self.emb = AlphaCombinedTimestepSizeEmbeddings( + self.embedding_dim, + size_emb_dim=self.embedding_dim // 3, + name="emb", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + self.silu = jax.nn.silu + self.linear = DenseGeneral( + self.embedding_coefficient * self.embedding_dim, + use_bias=True, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear", + ) + + def __call__( + self, + timestep: jnp.ndarray, + added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[jnp.dtype] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Compute AdaLayerNorm-Single modulation. + + Returns: + Tuple: + - Processed embedding after SiLU + linear transformation. + - Original embedded timestep. + """ + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py new file mode 100644 index 000000000..75692b703 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -0,0 +1,923 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from functools import partial +import math +from typing import Any, Dict, Optional, Tuple +from enum import Enum, auto +import jax +import jax.nn as jnn +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from jax.experimental.shard_map import shard_map +from jax.experimental.pallas.ops.tpu.flash_attention import ( + flash_attention as jax_flash_attention, + SegmentIds, + BlockSizes, +) + +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, Initializer +from maxdiffusion.models.ltx_video.transformers.activations import ( + GELU, + GEGLU, + ApproximateGELU, +) + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() + + +class Identity(nn.Module): + + def __call__(self, x): + return x + + +class BasicTransformerBlock(nn.Module): + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", "rms_norm"] + assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 2. 
Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 4. Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + def __call__( + self, + index: int, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + skip_layer_strategy = SkipLayerStrategy.AttentionValues + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") + + batch_size = hidden_states.shape[0] + + # 0. 
Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = (jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2)) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. Self-Attention + + attn_output = self.attn1( + norm_hidden_states, + block_index=index, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) + + attn_output = nn.with_logical_constraint(attn_output, ("activation_batch", "activation_norm_length", "activation_embed")) + + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint(attn_input, ("activation_batch", "activation_norm_length", "activation_embed")) + attn_output = self.attn2( + attn_input, + block_index=-1, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint(ff_output, ("activation_batch", "activation_norm_length", "activation_embed")) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states + + +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) + + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="to_out.0", + matmul_precision=self.matmul_precision, + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any(device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print("Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + block_index: int = -1, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} #noqa: F821 + assert cross_attention_kwargs.get("scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape( + skip_layer_mask[block_index], (batch_size, 1, 1) + ) # here skip_layer_mask is (48,3), changed this currently! 
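The skip-layer mask indexed above is laid out as (num_layers, batch), matching the (48, 3) shape mentioned in the inline comment: entry [block, sample] decides whether that transformer block's attention output is kept (1.0) or, under SkipLayerStrategy.AttentionValues, replaced by the raw values for that sample. A minimal sketch of how such a mask could be assembled from a skip_block_list like the ones in ltx_video.yml follows; build_skip_layer_mask is a hypothetical helper, and the assumption that the last batch entry is the STG-perturbed sample is illustrative only:

import jax.numpy as jnp

def build_skip_layer_mask(num_layers: int, batch_size: int, skip_blocks: list) -> jnp.ndarray:
  # Shape (num_layers, batch): 1.0 keeps a block's attention output for that sample,
  # 0.0 swaps it for the raw values on the skip-layer-guidance (STG) entry.
  mask = jnp.ones((num_layers, batch_size), dtype=jnp.float32)
  for block in skip_blocks:
    mask = mask.at[block, -1].set(0.0)  # assumes the last batch entry is the STG sample
  return mask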
+ + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key = self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + value_for_stg = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") + + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) + + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) + + hidden_states_a = self.attention(query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype) + + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + elif skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionValues: + hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def 
prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states + + +class AttentionOp(nn.Module): + + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. 
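+ # In short: for [batch, heads, length, head_dim] activations, the spec below shards batch over
+ # "data", heads over "fsdp" and head_dim over "tensor", leaving the sequence axis unsharded.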
+
+ # qkvo_sharding_spec = jax.sharding.PartitionSpec(
+ # ("data", "fsdp", "fsdp_transpose", "expert"),
+ # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
+ # None,
+ # None,
+ # )
+ # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence")
+ qkvo_sharding_spec = jax.sharding.PartitionSpec(
+ "data",
+ "fsdp",
+ None,
+ "tensor",
+ )
+ # Based on: ("activation_kv_batch", "activation_length")
+ qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None)
+ wrapped_flash_attention = shard_map(
+ partial_flash_attention,
+ mesh=sharding_mesh,
+ in_specs=(
+ qkvo_sharding_spec,
+ qkvo_sharding_spec,
+ qkvo_sharding_spec,
+ qkv_segment_ids_spec,
+ qkv_segment_ids_spec,
+ ),
+ out_specs=qkvo_sharding_spec,
+ check_rep=False,
+ )
+ else:
+ wrapped_flash_attention = partial_flash_attention
+
+ return wrapped_flash_attention(
+ q,
+ k,
+ v,
+ q_segment_ids,
+ kv_segment_ids,
+ )
+
+ def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes:
+ """
+ Default block sizes for Flash Attention.
+
+ TPU kernel ops run over a grid of blocks: the larger the block, the more data is staged in SRAM,
+ and we want to utilize the SRAM as fully as possible.
+
+ Blocks that are too large will cause cache misses and slow down the computation while the faster SRAM
+ waits on block data from the slower HBM.
+
+ A certain balance has to be struck to get the best performance. That balance should be derived from
+ the information supplied by q and k (which give the query and key/value sequence lengths)
+ together with the SRAM cache size of the target chip.
+
+ ** SRAM cache size for TPU
+ V5P - 1MB SRAM per core
+
+ Args:
+ q (jax.Array): Query tensor to be used
+ k (jax.Array): Key tensor to be used
+
+ Returns:
+ BlockSizes: Grid block sizes
+ """
+ max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+ return BlockSizes(
+ block_q=min(max_block_size, q.shape[-2]),
+ block_k_major=min(max_block_size, k.shape[-2]),
+ block_k=min(max_block_size, k.shape[-2]),
+ block_b=min(1, q.shape[0]),
+ block_q_major_dkv=min(max_block_size, q.shape[-2]),
+ block_k_major_dkv=min(max_block_size, k.shape[-2]),
+ block_q_dkv=min(max_block_size, q.shape[-2]),
+ block_k_dkv=min(max_block_size, k.shape[-2]),
+ block_q_dq=min(max_block_size, q.shape[-2]),
+ block_k_dq=min(512, k.shape[-2]),
+ block_k_major_dq=min(max_block_size, k.shape[-2]),
+ )
+
+
+class ExplicitAttention(nn.Module):
+
+ def __call__(
+ self,
+ q: jax.Array,
+ k: jax.Array,
+ v: jax.Array,
+ q_segment_ids: jax.Array,
+ kv_segment_ids: jax.Array,
+ sharding_mesh: Optional[jax.sharding.Mesh] = None,
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ assert sharding_mesh is None, "Explicit attention does not support sharding mesh."
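+ # Non-flash reference path: builds a boolean mask from the segment ids (True where a query token
+ # and a key/value token share a segment) and computes softmax(q @ k^T * scale + bias) @ v directly,
+ # materializing the full attention matrix in memory.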
+ attn_mask = None + if kv_segment_ids is not None: + q_segment_ids_expanded = q_segment_ids[:, None, :, None] + kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] + attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded + + scale_factor = 1 / jnp.sqrt(q.shape[-1]) + attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == jnp.bool_: + attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = q @ k.swapaxes(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = jnn.softmax(attn_weight, axis=-1) + + return attn_weight @ v + + +class RMSNorm(nn.Module): + """ + RMSNorm is a normalization layer that normalizes the input using the root mean square. + """ + + epsilon: float + dtype: jnp.dtype = jnp.float32 + elementwise_affine: bool = True + weight_dtype: jnp.dtype = jnp.float32 + kernel_axes: Tuple[Optional[str], ...] = () + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, hidden_states: jax.Array) -> jax.Array: + """ + Forward pass of the RMSNorm layer. + + First we compute the variance (mean of the square of the input) + and then normalize the input using the root mean square. + + NOTE: if weight is in mixed precision, the operand should be in the same precision. + Args: + hidden_states (jax.Array): Input data + + Returns: + jax.Array: Normed data + """ + + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) + + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
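+ inner_dim (`int`, *optional*): If given, used directly as the hidden width instead of `dim * mult`.
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): Dtype used for activations in the dense layers.
+ weight_dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): Dtype used for the layer parameters.
+ matmul_precision (`str`, *optional*, defaults to `"default"`): JAX matmul precision passed to the dense layers.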
+ """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError(f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + return hidden_states + + +def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: + """ + Integrates positional information into input tensors using RoPE. + + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") + + cos_freqs, sin_freqs = freqs_cis + + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) + + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py new file mode 100644 index 000000000..d8240989c --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -0,0 +1,56 @@ +# Copyright 2025 Lightricks Ltd. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from flax import linen as nn +import jax.numpy as jnp + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.activations import approximate_gelu + + +class CaptionProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + """ + + in_features: int + hidden_size: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, caption): + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + )(caption) + hidden_states = approximate_gelu(hidden_states) + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + )(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py new file mode 100644 index 000000000..1c1807fdd --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -0,0 +1,326 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import List, Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.adaln import AdaLayerNormSingle +from maxdiffusion.models.ltx_video.transformers.attention import BasicTransformerBlock +from maxdiffusion.models.ltx_video.transformers.caption_projection import CaptionProjection +from maxdiffusion.models.ltx_video.gradient_checkpoint import GradientCheckpointType +from maxdiffusion.models.ltx_video.repeatable_layer import RepeatableLayer + + +class Transformer3DModel(nn.Module): + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." 
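+ # setup() wires together: the patchify input projection, the RoPE frequency precomputer, the
+ # AdaLN-single timestep embedder, the (optionally scan/remat-wrapped) stack of BasicTransformerBlocks,
+ # and the final LayerNorm + output projection, plus an optional caption projection for the
+ # text-encoder embeddings.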
+ self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( #noqa: C408 + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def init_weights(self, in_channels, key, caption_channels, eval_only=True): + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", 
"encoder_attention_mask"] else jnp.bool + ) + + if eval_only: + return jax.eval_shape( + self.init, + key, + **example_inputs, + )["params"] + else: + return self.init(key, **example_inputs)["params"] + + def create_skip_layer_mask( + self, + batch_size: int, + num_conds: int, + ptb_index: int, + skip_block_list: Optional[List[int]] = None, + ) -> Optional[jnp.ndarray]: + if skip_block_list is None or len(skip_block_list) == 0: + return None + mask = jnp.ones((self.num_layers, batch_size * num_conds), dtype=self.dtype) + + for block_idx in skip_block_list: + mask = mask.at[block_idx, ptb_index::num_conds].set(0) + + return mask + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + skip_layer_mask=None, + skip_layer_strategy=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + + if self.num_layers > 0: + + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + skip_layer_mask, + skip_layer_strategy, + ) + # Output processing + + scale_shift_values = self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] + scale_shift_values = nn.with_logical_constraint( + scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states + + +def log_base(x: jax.Array, base: jax.Array) -> jax.Array: + """ + Computes log of x with defined base. + + Args: + x (jax.Array): log value + base (jax.Array): base of the log + + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) + + +class FreqsCisPrecomputer(nn.Module): + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + dtype = jnp.float32 # We need full precision in the freqs_cis computation. 
+ dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) diff --git a/src/maxdiffusion/models/ltx_video/utils/__init__.py b/src/maxdiffusion/models/ltx_video/utils/__init__.py new file mode 100644 index 000000000..cb4a6b9ce --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json new file mode 100644 index 000000000..bce38fb20 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -0,0 +1,26 @@ +{ + "ckpt_path": "/mnt/disks/diffusionproj/jax_weights", + "activation_fn": "gelu-approximate", + "attention_bias": true, + "attention_head_dim": 128, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 4096, + "double_self_attention": false, + "dropout": 0.0, + "norm_elementwise_affine": false, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 48, + "only_cross_attention": false, + "out_channels": 128, + "upcast_attention": false, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, + "in_channels": 128 +} \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/ltx_video/__init__.py b/src/maxdiffusion/pipelines/ltx_video/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py new file mode 100644 index 000000000..c4767e8ee --- /dev/null +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -0,0 +1,1018 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
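+#
+# Usage sketch (illustrative only; parameter values are examples, not defaults):
+#
+#   pipeline = LTXVideoPipeline.from_pretrained(config)
+#   frames = pipeline(height=512, width=768, num_frames=121, is_video=True, output_type="pil")
+#
+# The transformer denoising loop runs in JAX under the device mesh; the VAE, patchifier and latent
+# preparation below remain Torch-based, so arrays are converted at the JAX/Torch boundary.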
+import math +import os +from jax import Array +from maxdiffusion.models.ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler +from diffusers import AutoencoderKL +from typing import Optional, List, Union, Tuple +from einops import rearrange +import torch.nn.functional as F +from diffusers.utils.torch_utils import randn_tensor +from transformers import ( + FlaxT5EncoderModel, + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, +) +import json +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from maxdiffusion.models.ltx_video.models.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from maxdiffusion.models.ltx_video.models.autoencoders.vae_encode import ( + get_vae_size_scale_factor, + latent_to_pixel_coords, + vae_decode, + vae_encode, + un_normalize_latents, + normalize_latents, +) +from diffusers.image_processor import VaeImageProcessor +from maxdiffusion.models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt +from types import NoneType +from typing import Any, Dict + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier +from ...pyconfig import HyperParameters +from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState +from ...max_utils import (create_device_mesh, setup_initial_state, get_memory_allocations) +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import functools +import orbax.checkpoint as ocp + + +def prepare_extra_step_kwargs(generator): + extra_step_kwargs = {} + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + +class LTXVideoPipeline: + + def __init__( + self, + transformer: Transformer3DModel, + scheduler: FlaxRectifiedFlowMultistepScheduler, + scheduler_state: RectifiedFlowSchedulerState, + vae: AutoencoderKL, + text_encoder, + patchifier, + tokenizer, + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + devices_array: np.array, + mesh: Mesh, + config: HyperParameters, + transformer_state: Dict[Any, Any] = None, + transformer_state_shardings: Dict[Any, Any] = NoneType, + ): + self.transformer = transformer + self.devices_array = devices_array + self.mesh = mesh + self.config = config + self.p_run_inference = None + self.transformer_state = transformer_state + self.transformer_state_shardings = transformer_state_shardings + self.scheduler = scheduler + self.scheduler_state = scheduler_state + self.vae = vae + self.text_encoder = text_encoder + self.patchifier = patchifier + self.tokenizer = tokenizer + self.prompt_enhancer_image_caption_model = prompt_enhancer_image_caption_model + self.prompt_enhancer_image_caption_processor = prompt_enhancer_image_caption_processor + self.prompt_enhancer_llm_model = prompt_enhancer_llm_model + self.prompt_enhancer_llm_tokenizer = prompt_enhancer_llm_tokenizer + self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(self.vae) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + @classmethod + def load_scheduler(cls, ckpt_path, config): + if config.sampler == "from_checkpoint" or not config.sampler: + scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) + else: + scheduler = FlaxRectifiedFlowMultistepScheduler( + sampler=("Uniform" 
if config.sampler.lower() == "uniform" else "LinearQuadratic") + ) + scheduler_state = scheduler.create_state() + + return scheduler, scheduler_state + + @classmethod + def load_transformer(cls, config): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + base_dir = os.path.dirname(__file__) + config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True + ) + ##load in jax weights checkpoint + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + return transformer, transformer_state, transformer_state_shardings + + @classmethod + def load_vae(cls, ckpt_path): + vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) + return vae + + @classmethod + def load_text_encoder(cls, ckpt_path): + t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) + return t5_encoder + + @classmethod + def load_tokenizer(cls, config, ckpt_path): + t5_tokenizer = AutoTokenizer.from_pretrained(ckpt_path, max_length=config.max_sequence_length, use_fast=True) + return t5_tokenizer + + @classmethod + def load_prompt_enhancement(cls, config): + prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, + torch_dtype="bfloat16", + ) + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, + ) + return ( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + ) + + @classmethod + def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + transformer, transformer_state, transformer_state_shardings = cls.load_transformer(config) + + # load from pytorch version + models_dir = config.models_dir + ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" + if not os.path.isfile(ltxv_model_name_or_path): + ltxv_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + 
filename=ltxv_model_name_or_path, + local_dir=models_dir, + repo_type="model", + ) + else: + ltxv_model_path = ltxv_model_name_or_path + + scheduler, scheduler_state = cls.load_scheduler(ltxv_model_path, config) + vae = cls.load_vae(ltxv_model_path) + vae = vae.to(torch.bfloat16) + text_encoder = cls.load_text_encoder(config.text_encoder_model_name_or_path) + patchifier = SymmetricPatchifier(patch_size=1) + tokenizer = cls.load_tokenizer(config, config.text_encoder_model_name_or_path) + + enhance_prompt = False + if enhance_prompt: + ( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + ) = cls.load_prompt_enhancement(config) + else: + ( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + ) = (None, None, None, None) + + return LTXVideoPipeline( + transformer=transformer, + scheduler=scheduler, + scheduler_state=scheduler_state, + vae=vae, + text_encoder=text_encoder, + patchifier=patchifier, + tokenizer=tokenizer, + prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model=prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, + devices_array=devices_array, + mesh=mesh, + config=config, + transformer_state=transformer_state, + transformer_state_shardings=transformer_state_shardings, + ) + + @classmethod + def _text_preprocessing(self, text): + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + text = text.strip() + return text + + return [process(t) for t in text] + + def denoising_step( + scheduler, + latents: Array, + noise_pred: Array, + current_timestep: Optional[Array], + conditioning_mask: Optional[Array], + t: float, + extra_step_kwargs: Dict, + t_eps: float = 1e-6, + stochastic_sampling: bool = False, + ) -> Array: + # Denoise the latents using the scheduler + denoised_latents = scheduler.step( + noise_pred, + t if current_timestep is None else current_timestep, + latents, + **extra_step_kwargs, + stochastic_sampling=stochastic_sampling, + ) + + if conditioning_mask is None: + return denoised_latents + + tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) + tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) + return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) + + def retrieve_timesteps( # currently doesn't support custom timesteps + self, + scheduler: FlaxRectifiedFlowMultistepScheduler, + latent_shape, + scheduler_state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + ): + scheduler_state = scheduler.set_timesteps( + state=scheduler_state, samples_shape=latent_shape, num_inference_steps=num_inference_steps + ) + timesteps = scheduler_state.timesteps + if ( + skip_initial_inference_steps < 0 + or skip_final_inference_steps < 0 + or skip_initial_inference_steps + skip_final_inference_steps >= num_inference_steps + ): + raise ValueError( + "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" + ) + timesteps = timesteps[skip_initial_inference_steps : len(timesteps) 
- skip_final_inference_steps] + scheduler_state = scheduler.set_timesteps(timesteps=timesteps, samples_shape=latent_shape, state=scheduler_state) + + return scheduler_state + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + text_encoder_max_tokens: int = 256, + **kwargs, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = text_encoder_max_tokens + if prompt_embeds is None: + assert ( + self.text_encoder is not None + ), "You should provide either prompt_embeds or self.text_encoder should not be None," + + prompt = self._text_preprocessing(prompt) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = jnp.array(text_inputs.input_ids) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) #noqa: F841 + + prompt_attention_mask = jnp.array(text_inputs.attention_mask) + prompt_embeds = self.text_encoder(text_input_ids, attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype #noqa: F841 + elif self.transformer is not None: + dtype = self.transformer.dtype #noqa: F841 + else: + dtype = None #noqa: F841 + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + prompt_embeds = jnp.reshape(prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) + prompt_attention_mask = jnp.tile(prompt_attention_mask, (1, num_images_per_prompt)) + prompt_attention_mask = jnp.reshape(prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = self._text_preprocessing(negative_prompt) + uncond_tokens = uncond_tokens * batch_size + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = jnp.array(uncond_input.attention_mask) + + negative_prompt_embeds = self.text_encoder( + jnp.array(uncond_input.input_ids), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = jnp.tile(negative_prompt_embeds, (1, num_images_per_prompt, 1)) + negative_prompt_embeds = jnp.reshape(negative_prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1)) + + negative_prompt_attention_mask = jnp.tile(negative_prompt_attention_mask, 
(1, num_images_per_prompt)) + negative_prompt_attention_mask = jnp.reshape(negative_prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) + + def prepare_latents( ## this is in pytorch + self, + latents: torch.Tensor | None, + media_items: torch.Tensor | None, + timestep: float, + latent_shape: torch.Size | Tuple[Any, ...], + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | List[torch.Generator], + vae_per_channel_normalize: bool = True, + ): + if isinstance(generator, list) and len(generator) != latent_shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators." + ) + + # Initialize the latents with the given latents or encoded media item, if provided + assert ( + latents is None or media_items is None + ), "Cannot provide both latents and media_items. Please provide only one of the two." + + assert ( + latents is None and media_items is None or timestep < 1.0 + ), "Input media_item or latents are provided, but they will be replaced with noise." + + if media_items is not None: + latents = vae_encode( + media_items, + self.vae, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + if latents is not None: + assert latents.shape == latent_shape, f"Latents have to be of shape {latent_shape} but are {latents.shape}." + + # For backward compatibility, generate in the "patchified" shape and rearrange + b, c, f, h, w = latent_shape + noise = randn_tensor((b, f * h * w, c), generator=generator, device=device, dtype=dtype) + noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w) + + # scale the initial noise by the standard deviation required by the scheduler + # noise = noise * self.scheduler.init_noise_sigma !!this doesn;t have + + if latents is None: + latents = noise + else: + # Noise the latents to the required (first) timestep + timestep = torch.from_numpy(np.array(timestep)) + latents = timestep * noise + (1 - timestep) * latents + + return latents + + def prepare_conditioning( # removed conditioning_item logic + self, + conditioning_items, + init_latents: torch.Tensor, + num_frames: int, + height: int, + width: int, + vae_per_channel_normalize: bool = True, + generator=None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + + assert isinstance(self.vae, CausalVideoAutoencoder) + + # Patchify the updated latents and calculate their pixel coordinates + init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents) + init_pixel_coords = latent_to_pixel_coords( + init_latent_coords, + self.vae, + # causal_fix=self.transformer.config.causal_temporal_positioning, set to false now + causal_fix=True, + ) + + if not conditioning_items: + return init_latents, init_pixel_coords, None, 0 + + def __call__( + self, + height: int, + width: int, + num_frames: int, + negative_prompt: str = "", + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + frame_rate: int = 30, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: 
Optional[torch.FloatTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ guidance_timesteps: Optional[List[int]] = None,
+ decode_timestep: Union[List[float], float] = 0.05,
+ decode_noise_scale: Optional[List[float]] = 0.025,
+ offload_to_cpu: bool = False,
+ enhance_prompt: bool = False,
+ text_encoder_max_tokens: int = 256,
+ num_inference_steps: int = 50,
+ guidance_scale: Union[float, List[float]] = 4.5,
+ rescaling_scale: Union[float, List[float]] = 0.7,
+ stg_scale: Union[float, List[float]] = 1.0,
+ skip_initial_inference_steps: int = 0,
+ skip_final_inference_steps: int = 0,
+ cfg_star_rescale: bool = False,
+ skip_block_list: Optional[Union[List[List[int]], List[int]]] = None,
+ **kwargs,
+ ):
+ enhance_prompt = False
+ prompt = self.config.prompt
+ is_video = kwargs.get("is_video", False)
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", True)
+
+ latent_height = height // self.vae_scale_factor
+ latent_width = width // self.vae_scale_factor
+ latent_num_frames = num_frames // self.video_scale_factor
+ if isinstance(self.vae, CausalVideoAutoencoder) and is_video:
+ latent_num_frames += 1
+ base_dir = os.path.dirname(__file__)
+ config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json")
+ with open(config_path, "r") as f:
+ model_config = json.load(f)
+
+ latent_shape = (
+ batch_size * num_images_per_prompt,
+ model_config["in_channels"],
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ )
+ scheduler_state = self.retrieve_timesteps(
+ self.scheduler,
+ latent_shape,
+ self.scheduler_state,
+ num_inference_steps,
+ None,
+ skip_initial_inference_steps,
+ skip_final_inference_steps,
+ )
+
+ guidance_mapping = []
+
+ if guidance_timesteps:
+ for timestep in scheduler_state.timesteps:
+ indices = [i for i, val in enumerate(guidance_timesteps) if val <= timestep]
+ guidance_mapping.append(indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1))
+
+ if not isinstance(guidance_scale, list):
+ guidance_scale = [guidance_scale] * len(scheduler_state.timesteps)
+ else:
+ guidance_scale = [guidance_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))]
+
+ if not isinstance(stg_scale, list):
+ stg_scale = [stg_scale] * len(scheduler_state.timesteps)
+ else:
+ stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))]
+
+ if not isinstance(rescaling_scale, list):
+ rescaling_scale = [rescaling_scale] * len(scheduler_state.timesteps)
+ else:
+ rescaling_scale = [rescaling_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))]
+
+ guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale]
+ do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale)
+ do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale)
+ do_rescaling = any(x != 1.0 for x in rescaling_scale)
+
+ num_conds = 1
+ if do_classifier_free_guidance:
+ num_conds += 1
+ if do_spatio_temporal_guidance:
+ num_conds += 1
+
+ is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list)
+
+ if not is_list_of_lists:
+ skip_block_list = [skip_block_list] * len(scheduler_state.timesteps)
+ else:
+ new_skip_block_list = []
+ for i
in range(len(scheduler_state.timesteps)): + new_skip_block_list.append(skip_block_list[guidance_mapping[i]]) + + skip_block_list = new_skip_block_list + + if do_spatio_temporal_guidance: + if skip_block_list is not None: + skip_layer_masks = [ + self.transformer.create_skip_layer_mask(batch_size, num_conds, num_conds - 1, skip_blocks) + for skip_blocks in skip_block_list + ] + if enhance_prompt: + prompt = generate_cinematic_prompt( + self.prompt_enhancer_image_caption_model, + self.prompt_enhancer_image_caption_processor, + self.prompt_enhancer_llm_model, + self.prompt_enhancer_llm_tokenizer, + prompt, + None, # conditioning items set to None + max_new_tokens=text_encoder_max_tokens, + ) + + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=None, # device set to none + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + text_encoder_max_tokens=text_encoder_max_tokens, + ) + prompt_embeds_batch = prompt_embeds + prompt_attention_mask_batch = prompt_attention_mask + if do_classifier_free_guidance: + prompt_embeds_batch = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate([negative_prompt_attention_mask, prompt_attention_mask], axis=0) + if do_spatio_temporal_guidance: + prompt_embeds_batch = jnp.concatenate([prompt_embeds_batch, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate( + [ + prompt_attention_mask_batch, + prompt_attention_mask, + ], + axis=0, + ) + latents = self.prepare_latents( + latents=latents, + media_items=None, # set to None + timestep=scheduler_state.timesteps[0], + latent_shape=latent_shape, + dtype=None, + device=None, + generator=generator, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + + latents, pixel_coords, conditioning_mask, num_cond_latents = self.prepare_conditioning( + conditioning_items=None, + init_latents=latents, + num_frames=num_frames, + height=height, + width=width, + vae_per_channel_normalize=vae_per_channel_normalize, + generator=generator, + ) + + + pixel_coords = torch.cat([pixel_coords] * num_conds) + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + + noise_cond = jnp.ones((1, 1)) # initialize first round with this! 
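+ # Note: noise_cond is only an initial placeholder of shape (1, 1); the per-step timestep fed to the
+ # transformer is rebuilt from scheduler_state.timesteps and broadcast to the batch inside run_inference.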
+ p_run_inference = functools.partial( + run_inference, + transformer=self.transformer, + config=self.config, + mesh=self.mesh, + fractional_cords=jnp.array(fractional_coords.to(torch.float32).detach().numpy()), + prompt_embeds=prompt_embeds_batch, + segment_ids=None, + encoder_attention_segment_ids=prompt_attention_mask_batch, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + do_classifier_free_guidance=do_classifier_free_guidance, + num_conds=num_conds, + guidance_scale=guidance_scale, + do_spatio_temporal_guidance=do_spatio_temporal_guidance, + stg_scale=stg_scale, + do_rescaling=do_rescaling, + rescaling_scale=rescaling_scale, + batch_size=batch_size, + skip_layer_masks=skip_layer_masks, + cfg_star_rescale=cfg_star_rescale, + ) + + with self.mesh: + latents, scheduler_state = p_run_inference( + transformer_state=self.transformer_state, + latents=jnp.array(latents.to(torch.float32).detach().numpy()), + timestep=noise_cond, + scheduler_state=scheduler_state, + ) + latents = torch.from_numpy(np.array(latents)) + latents = latents[:, num_cond_latents:] + + latents = self.patchifier.unpatchify( + latents=latents, + output_height=latent_height, + output_width=latent_width, + out_channels=model_config["in_channels"] // math.prod(self.patchifier.patch_size), + ) + if output_type != "latent": + if self.vae.decoder.timestep_conditioning: + noise = torch.randn_like(latents) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * latents.shape[0] + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * latents.shape[0] + + decode_timestep = torch.tensor(decode_timestep).to(latents.device) + decode_noise_scale = torch.tensor(decode_noise_scale).to(latents.device)[:, None, None, None, None] + latents = latents * (1 - decode_noise_scale) + noise * decode_noise_scale + else: + decode_timestep = None + image = vae_decode( + latents, + self.vae, + is_video, + vae_per_channel_normalize=kwargs.get("vae_per_channel_normalize", True), + timestep=decode_timestep, + ) + image = self.image_processor.postprocess(image, output_type=output_type) + + else: + image = latents + + # Offload all models + + if not return_dict: + return (image,) + + return image + + +def transformer_forward_pass( + latents, + state, + noise_cond, + transformer, + fractional_cords, + prompt_embeds, + segment_ids, + encoder_attention_segment_ids, + skip_layer_mask, +): + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + skip_layer_mask=skip_layer_mask, + ) + return noise_pred, state + + +def run_inference( + transformer_state, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + num_inference_steps, + scheduler, + segment_ids, + encoder_attention_segment_ids, + scheduler_state, + do_classifier_free_guidance, + num_conds, + guidance_scale, + do_spatio_temporal_guidance, + stg_scale, + do_rescaling, + rescaling_scale, + batch_size, + skip_layer_masks, + cfg_star_rescale, +): + for i, t in enumerate(scheduler_state.timesteps): + current_timestep = t + latent_model_input = jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents + if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): + if 
isinstance(current_timestep, float): + dtype = jnp.float32 + else: + dtype = jnp.int32 + + current_timestep = jnp.array( + [current_timestep], + dtype=dtype, + ) + elif current_timestep.ndim == 0: + current_timestep = jnp.expand_dims(current_timestep, axis=0) + + # Broadcast to batch dimension + current_timestep = jnp.broadcast_to(current_timestep, (latent_model_input.shape[0], 1)) + + noise_pred, transformer_state = transformer_forward_pass( + latent_model_input, + transformer_state, + current_timestep, + transformer, + fractional_cords, + prompt_embeds, + segment_ids, + encoder_attention_segment_ids, + skip_layer_mask=(skip_layer_masks[i] if skip_layer_masks is not None else None), + ) + + if do_spatio_temporal_guidance: + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_text = chunks[-2] + noise_pred_text_perturb = chunks[-1] + + if do_classifier_free_guidance: + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_uncond = chunks[0] + noise_pred_text = chunks[1] + if cfg_star_rescale: + positive_flat = noise_pred_text.reshape(batch_size, -1) + negative_flat = noise_pred_uncond.reshape(batch_size, -1) + dot_product = jnp.sum(positive_flat * negative_flat, axis=1, keepdims=True) + squared_norm = jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 + alpha = dot_product / squared_norm + alpha = alpha.reshape(batch_size, 1, 1) + + noise_pred_uncond = alpha * noise_pred_uncond + noise_pred = noise_pred_uncond + guidance_scale[i] * (noise_pred_text - noise_pred_uncond) + elif do_spatio_temporal_guidance: + noise_pred = noise_pred_text + + if do_spatio_temporal_guidance: + noise_pred = noise_pred + stg_scale[i] * (noise_pred_text - noise_pred_text_perturb) + if do_rescaling and stg_scale[i] > 0.0: + noise_pred_text_std = jnp.std(noise_pred_text.reshape(batch_size, -1), axis=1, keepdims=True) + noise_pred_std = jnp.std(noise_pred.reshape(batch_size, -1), axis=1, keepdims=True) + + factor = noise_pred_text_std / noise_pred_std + factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) + + noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) + current_timestep = current_timestep[:1] + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() + + return latents, scheduler_state + + +def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on + statistics from a reference latent tensor. + + Args: + latent (torch.Tensor): Input latents to normalize + reference_latent (torch.Tensor): The reference latents providing style statistics. + factor (float): Blending factor between original and transformed latent. 
+      Range: -10.0 to 10.0, Default: 1.0
+
+  Returns:
+      torch.Tensor: The transformed latent tensor
+  """
+  result = latents.clone()
+
+  for i in range(latents.size(0)):
+    for c in range(latents.size(1)):
+      r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None)  # index by original dim order
+      i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
+
+      result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
+
+  result = torch.lerp(latents, result, factor)
+  return result
+
+
+class LTXMultiScalePipeline:
+
+  @classmethod
+  def load_latent_upsampler(cls, config):
+    spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path
+    if not spatial_upscaler_model_name_or_path:
+      raise ValueError(
+          "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
+      )
+
+    if not os.path.isfile(spatial_upscaler_model_name_or_path):
+      spatial_upscaler_model_path = hf_hub_download(
+          repo_id="Lightricks/LTX-Video",
+          filename=spatial_upscaler_model_name_or_path,
+          local_dir=config.models_dir,
+          repo_type="model",
+      )
+    else:
+      spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
+    latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path)
+    latent_upsampler.eval()
+    return latent_upsampler
+
+  def _upsample_latents(self, latent_upsampler: LatentUpsampler, latents: torch.Tensor):
+    assert latents.device == latent_upsampler.device
+
+    latents = un_normalize_latents(latents, self.vae, vae_per_channel_normalize=True)
+    upsampled_latents = latent_upsampler(latents)
+    upsampled_latents = normalize_latents(upsampled_latents, self.vae, vae_per_channel_normalize=True)
+    return upsampled_latents
+
+  def __init__(self, video_pipeline: LTXVideoPipeline):
+    self.video_pipeline = video_pipeline
+    self.vae = video_pipeline.vae
+
+  def __call__(self, height, width, num_frames, output_type, generator, config) -> Any:
+
+    latent_upsampler = self.load_latent_upsampler(config)
+    original_output_type = output_type
+    output_type = "latent"
+    result = self.video_pipeline(
+        height=height,
+        width=width,
+        num_frames=num_frames,
+        is_video=True,
+        output_type=output_type,
+        generator=generator,
+        guidance_scale=config.first_pass["guidance_scale"],
+        stg_scale=config.first_pass["stg_scale"],
+        rescaling_scale=config.first_pass["rescaling_scale"],
+        skip_initial_inference_steps=config.first_pass["skip_initial_inference_steps"],
+        skip_final_inference_steps=config.first_pass["skip_final_inference_steps"],
+        num_inference_steps=config.first_pass["num_inference_steps"],
+        guidance_timesteps=config.first_pass["guidance_timesteps"],
+        cfg_star_rescale=config.first_pass["cfg_star_rescale"],
+        skip_block_list=config.first_pass["skip_block_list"],
+    )
+    latents = result
+    upsampled_latents = self._upsample_latents(latent_upsampler, latents)
+    upsampled_latents = adain_filter_latent(latents=upsampled_latents, reference_latents=latents)
+
+    latents = upsampled_latents
+    output_type = original_output_type
+
+    result = self.video_pipeline(
+        height=height * 2,
+        width=width * 2,
+        num_frames=num_frames,
+        is_video=True,
+        output_type=output_type,
+        latents=latents,
+        generator=generator,
+        guidance_scale=config.second_pass["guidance_scale"],
+        stg_scale=config.second_pass["stg_scale"],
+        rescaling_scale=config.second_pass["rescaling_scale"],
+        skip_initial_inference_steps=config.second_pass["skip_initial_inference_steps"],
+        skip_final_inference_steps=config.second_pass["skip_final_inference_steps"],
num_inference_steps=config.second_pass["num_inference_steps"], + guidance_timesteps=config.second_pass["guidance_timesteps"], + cfg_star_rescale=config.second_pass["cfg_star_rescale"], + skip_block_list=config.second_pass["skip_block_list"], + ) + + if original_output_type != "latent": + num_frames = result.shape[2] + videos = rearrange(result, "b c f h w -> (b f) c h w") + + videos = F.interpolate( + videos, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) + result = videos + + return result diff --git a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py new file mode 100644 index 000000000..b550aeea3 --- /dev/null +++ b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py @@ -0,0 +1,327 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union +from dataclasses import dataclass +from pathlib import Path +import os +from safetensors import safe_open + +import flax +import jax +import jax.numpy as jnp +import json +from maxdiffusion.configuration_utils import ConfigMixin, register_to_config +from maxdiffusion.schedulers.scheduling_utils_flax import ( + CommonSchedulerState, + FlaxSchedulerMixin, + FlaxSchedulerOutput, +) + + +def linear_quadratic_schedule_jax( + num_steps: int, threshold_noise: float = 0.025, linear_steps: Optional[int] = None +) -> jnp.ndarray: + if num_steps == 1: + return jnp.array([1.0], dtype=jnp.float32) + if linear_steps is None: + linear_steps = num_steps // 2 + + linear_sigma_schedule = jnp.arange(linear_steps) * threshold_noise / linear_steps + + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_indices = jnp.arange(linear_steps, num_steps) + quadratic_sigma_schedule = quadratic_coef * (quadratic_indices**2) + linear_coef * quadratic_indices + const + + sigma_schedule = jnp.concatenate([linear_sigma_schedule, quadratic_sigma_schedule]) + sigma_schedule = jnp.concatenate([sigma_schedule, jnp.array([1.0])]) + sigma_schedule = 1.0 - sigma_schedule + return sigma_schedule[:-1].astype(jnp.float32) + + +def time_shift_jax(mu: float, sigma: float, t: jnp.ndarray) -> jnp.ndarray: + mu_f = jnp.array(mu, dtype=jnp.float32) + sigma_f = jnp.array(sigma, dtype=jnp.float32) + return jnp.exp(mu_f) / (jnp.exp(mu_f) + (1 / t - 1) ** sigma_f) + + +def _prod_jax(iterable): + return jnp.prod(jnp.array(iterable, dtype=jnp.float32)) + + +def get_normal_shift_jax( + n_tokens: int, + min_tokens: int = 1024, + max_tokens: int = 4096, + min_shift: float = 0.95, + max_shift: float = 2.05, +) -> float: + m = (max_shift - min_shift) / (max_tokens - min_tokens) + 
+  b = min_shift - m * min_tokens
+  return m * n_tokens + b
+
+
+def append_dims_jax(x: jnp.ndarray, target_dims: int) -> jnp.ndarray:
+  """Appends singleton dimensions to the end of a tensor until it reaches `target_dims`."""
+  return x[(...,) + (None,) * (target_dims - x.ndim)]
+
+
+def stretch_shifts_to_terminal_jax(shifts: jnp.ndarray, terminal: float = 0.1) -> jnp.ndarray:
+  if shifts.size == 0:
+    raise ValueError("The 'shifts' tensor must not be empty.")
+  if terminal <= 0 or terminal >= 1:
+    raise ValueError("The terminal value must be between 0 and 1 (exclusive).")
+
+  one_minus_z = 1.0 - shifts
+  # Using shifts[-1] for the last element
+  scale_factor = one_minus_z[-1] / (1.0 - terminal)
+  stretched_shifts = 1.0 - (one_minus_z / scale_factor)
+
+  return stretched_shifts
+
+
+def sd3_resolution_dependent_timestep_shift_jax(
+    samples_shape: Tuple[int, ...],
+    timesteps: jnp.ndarray,
+    target_shift_terminal: Optional[float] = None,
+) -> jnp.ndarray:
+  if len(samples_shape) == 3:
+    _, m, _ = samples_shape
+  elif len(samples_shape) in [4, 5]:
+    m = _prod_jax(samples_shape[2:])
+  else:
+    raise ValueError("Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)")
+
+  shift = get_normal_shift_jax(int(m))
+  time_shifts = time_shift_jax(shift, 1.0, timesteps)
+
+  if target_shift_terminal is not None:
+    time_shifts = stretch_shifts_to_terminal_jax(time_shifts, target_shift_terminal)
+  return time_shifts
+
+
+def simple_diffusion_resolution_dependent_timestep_shift_jax(
+    samples_shape: Tuple[int, ...],
+    timesteps: jnp.ndarray,
+    n: int = 32 * 32,
+) -> jnp.ndarray:
+  if len(samples_shape) == 3:
+    _, m, _ = samples_shape
+  elif len(samples_shape) in [4, 5]:
+    m = _prod_jax(samples_shape[2:])
+  else:
+    raise ValueError("Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)")
+  # Ensure m and n are float32 for calculations
+  m_f = jnp.array(m, dtype=jnp.float32)
+  n_f = jnp.array(n, dtype=jnp.float32)
+
+  snr = (timesteps / (1 - timesteps)) ** 2
+  shift_snr = jnp.log(snr) + 2 * jnp.log(m_f / n_f)  # scale the SNR by (m / n) ** 2 in log space
+  shifted_timesteps = jax.nn.sigmoid(0.5 * shift_snr)
+
+  return shifted_timesteps
+
+
+@flax.struct.dataclass
+class RectifiedFlowSchedulerState:
+  """
+  Data class to hold the mutable state of the RectifiedFlowScheduler.
+ """ + + common: CommonSchedulerState + init_noise_sigma: float + num_inference_steps: Optional[int] = None + timesteps: Optional[jnp.ndarray] = None + sigmas: Optional[jnp.ndarray] = None + + @classmethod + def create(cls, common_state: CommonSchedulerState, init_noise_sigma: float): + return cls( + common=common_state, + init_noise_sigma=init_noise_sigma, + num_inference_steps=None, + timesteps=None, + sigmas=None, + ) + + +@dataclass +class FlaxRectifiedFlowSchedulerOutput(FlaxSchedulerOutput): + state: RectifiedFlowSchedulerState + + +class FlaxRectifiedFlowMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): + + dtype: jnp.dtype + order = 1 + + @property + def has_state(self) -> bool: + return True + + @register_to_config + def __init__( + self, + num_train_timesteps=1000, + trained_betas: Optional[Union[jnp.ndarray, List[float]]] = None, + beta_schedule: str = "linear", + rescale_zero_terminal_snr: bool = False, + beta_start: float = 0.0001, + beta_end: float = 0.02, + shifting: Optional[str] = None, + base_resolution: int = 32**2, + target_shift_terminal: Optional[float] = None, + sampler: Optional[str] = "Uniform", + shift: Optional[float] = None, + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> RectifiedFlowSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + init_noise_sigma = 1.0 + return RectifiedFlowSchedulerState.create(common_state=common, init_noise_sigma=init_noise_sigma) + + def get_initial_timesteps_jax(self, num_timesteps: int, shift: Optional[float] = None) -> jnp.ndarray: + if self.config.sampler == "Uniform": + return jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype) + elif self.config.sampler == "LinearQuadratic": + return linear_quadratic_schedule_jax(num_timesteps).astype(self.dtype) + elif self.config.sampler == "Constant": + assert shift is not None, "Shift must be provided for constant time shift sampler." 
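+      # "Constant" sampler: apply a fixed sigmoid time shift (time_shift_jax) to a uniform timestep grid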
+      return time_shift_jax(shift, 1.0, jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype)).astype(
+          self.dtype
+      )
+    else:
+      raise ValueError(f"Sampler {self.config.sampler} is not supported.")
+
+  def shift_timesteps_jax(self, samples_shape: Tuple[int, ...], timesteps: jnp.ndarray) -> jnp.ndarray:
+    if self.config.shifting == "SD3":
+      return sd3_resolution_dependent_timestep_shift_jax(samples_shape, timesteps, self.config.target_shift_terminal)
+    elif self.config.shifting == "SimpleDiffusion":
+      return simple_diffusion_resolution_dependent_timestep_shift_jax(samples_shape, timesteps, self.config.base_resolution)
+    return timesteps
+
+  @staticmethod
+  def from_pretrained_jax(pretrained_model_path: Union[str, os.PathLike]):
+    pretrained_model_path = Path(pretrained_model_path)
+    config = None
+    if pretrained_model_path.is_file():
+      with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
+        metadata = f.metadata()
+        configs = json.loads(metadata["config"])
+        config = configs["scheduler"]
+
+    elif pretrained_model_path.is_dir():
+      diffusers_noise_scheduler_config_path = pretrained_model_path / "scheduler" / "scheduler_config.json"
+
+      if not diffusers_noise_scheduler_config_path.is_file():
+        raise FileNotFoundError(f"Scheduler config not found at {diffusers_noise_scheduler_config_path}")
+
+      with open(diffusers_noise_scheduler_config_path, "r") as f:
+        scheduler_config = json.load(f)
+      config = scheduler_config
+    return FlaxRectifiedFlowMultistepScheduler.from_config(config)
+
+  def set_timesteps(
+      self,
+      state: RectifiedFlowSchedulerState,
+      num_inference_steps: Optional[int] = None,
+      samples_shape: Optional[Tuple[int, ...]] = None,
+      timesteps: Optional[jnp.ndarray] = None,
+      device: Optional[str] = None,
+  ) -> RectifiedFlowSchedulerState:
+    if timesteps is not None and num_inference_steps is not None:
+      raise ValueError("You cannot provide both `timesteps` and `num_inference_steps`.")
+
+    # Exactly one of `num_inference_steps` or `timesteps` must be provided
+    if num_inference_steps is None and timesteps is None:
+      raise ValueError("Either `num_inference_steps` or `timesteps` must be provided.")
+
+    if timesteps is None:
+      num_inference_steps = jnp.minimum(self.config.num_train_timesteps, num_inference_steps)
+      timesteps = self.get_initial_timesteps_jax(num_inference_steps, shift=self.config.shift).astype(self.dtype)
+
+      # Apply shifting if samples_shape is provided and shifting is configured
+      if samples_shape is not None:
+        timesteps = self.shift_timesteps_jax(samples_shape, timesteps)
+    else:
+      timesteps = jnp.asarray(timesteps, dtype=self.dtype)
+      num_inference_steps = len(timesteps)
+
+    return state.replace(
+        timesteps=timesteps,
+        num_inference_steps=num_inference_steps,
+        sigmas=timesteps,  # sigmas are the same as timesteps in RF
+    )
+
+  def scale_model_input(
+      self, state: RectifiedFlowSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+  ) -> jnp.ndarray:
+    # Rectified Flow scheduler typically doesn't scale model input, returns as is.
+    return sample
+
+  def step(
+      self,
+      state: RectifiedFlowSchedulerState,
+      model_output: jnp.ndarray,
+      timestep: jnp.ndarray,
+      sample: jnp.ndarray,
+      return_dict: bool = True,
+      stochastic_sampling: bool = False,
+      generator: Optional[jax.random.PRNGKey] = None,
+  ) -> Union[FlaxRectifiedFlowSchedulerOutput, Tuple[jnp.ndarray, RectifiedFlowSchedulerState]]:
+    if state.num_inference_steps is None:
+      raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler")
+
+    t_eps = 1e-6  # Small epsilon for numerical issues
+
+    timesteps_padded = jnp.concatenate([state.timesteps, jnp.array([0.0], dtype=self.dtype)])
+
+    if timestep.ndim == 0:
+      idx = jnp.searchsorted(timesteps_padded, timestep - t_eps, side="right")  # noqa: F841
+      current_t_idx = jnp.where(state.timesteps == timestep, size=1, fill_value=len(state.timesteps))[0][0]
+      lower_timestep = jnp.where(current_t_idx + 1 < len(timesteps_padded), timesteps_padded[current_t_idx + 1], 0.0)
+      dt = timestep - lower_timestep
+    else:
+      current_t_indices = jnp.searchsorted(state.timesteps, timestep, side="right")  # timesteps is decreasing
+      current_t_indices = jnp.where(current_t_indices > 0, current_t_indices - 1, 0)  # adjust for right side search
+      lower_timestep_indices = jnp.minimum(current_t_indices + 1, len(timesteps_padded) - 1)
+      lower_timestep = timesteps_padded[lower_timestep_indices]
+      dt = timestep - lower_timestep
+    dt = append_dims_jax(dt, sample.ndim)
+
+    # Compute previous sample
+    if stochastic_sampling:
+      if generator is None:
+        raise ValueError("`generator` PRNGKey must be provided for stochastic sampling.")
+      broadcastable_timestep = append_dims_jax(timestep, sample.ndim)
+
+      x0 = sample - broadcastable_timestep * model_output
+      # Drop the trailing singleton dims added by append_dims_jax so next_timestep matches timestep's shape
+      next_timestep = timestep - jnp.squeeze(dt, axis=tuple(range(timestep.ndim, dt.ndim)))
+
+      noise = jax.random.normal(generator, sample.shape, dtype=self.dtype)
+      prev_sample = self.add_noise(state.common, x0, noise, next_timestep)
+    else:
+      prev_sample = sample - dt * model_output
+
+    if not return_dict:
+      return (prev_sample, state)
+
+    return FlaxRectifiedFlowSchedulerOutput(prev_sample=prev_sample, state=state)
diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py
new file mode 100644
index 000000000..9398c9156
--- /dev/null
+++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py
@@ -0,0 +1,202 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ """ + +import os +import torch +import jax +import numpy as np +import jax.numpy as jnp +import unittest +from absl.testing import absltest +from jax.sharding import Mesh +import json +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import functools +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import ( + create_device_mesh, + setup_initial_state, + get_memory_allocations, +) +from jax.sharding import PartitionSpec as P +import orbax.checkpoint as ocp + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def load_ref_prediction(): + base_dir = os.path.dirname(__file__) + saved_prediction_path = os.path.join(base_dir, "ltx_vid_transformer_test_ref_pred") + predict_dict = torch.load(saved_prediction_path) + noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) + return noise_pred_pt + + +def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): + latents, state, noise_cond = args + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + return noise_pred, state, noise_cond + + +def run_inference( + states, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + segment_ids, + encoder_attention_segment_ids, +): + transformer_state = states["transformer"] + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + fractional_cords=fractional_cords, + prompt_embeds=prompt_embeds, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) + return latents + + +class LTXTransformerTest(unittest.TestCase): + + def test_one_step_transformer(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"), + ], + unittest=True, + ) + config = pyconfig.config + noise_pred_pt = load_ref_prediction() + + # set up transformer + key = jax.random.PRNGKey(42) + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + base_dir = os.path.dirname(__file__) + config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") + + with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True + ) + + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + 
weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + states["transformer"] = transformer_state + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "latents": (batch_size, num_tokens, in_channels), + "fractional_coords": (batch_size, 3, num_tokens), + "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(example_inputs["latents"], data_sharding) + prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + config=config, + mesh=mesh, + latents=latents, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ), + in_shardings=(state_shardings,), + out_shardings=None, + ) + + noise_pred = p_run_inference(states).block_until_ready() + noise_pred = torch.from_numpy(np.array(noise_pred)) + + torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred new file mode 100644 index 000000000..0a9fe9120 Binary files /dev/null and b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred differ
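
Usage sketch for the new scheduler (a minimal illustration with placeholder shapes and a dummy model output, not taken from the diff; the real call sequence lives in run_inference above): FlaxRectifiedFlowMultistepScheduler follows the usual Flax scheduler flow of create_state() -> set_timesteps() -> step().

import jax.numpy as jnp
from maxdiffusion.schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler

# Default config: uniform sigma schedule, no resolution-dependent shifting.
scheduler = FlaxRectifiedFlowMultistepScheduler(sampler="Uniform")
state = scheduler.create_state()

latents = jnp.zeros((1, 256, 128), dtype=jnp.float32)  # placeholder (batch, tokens, channels)
state = scheduler.set_timesteps(state, num_inference_steps=8, samples_shape=latents.shape)

for t in state.timesteps:
  model_output = jnp.zeros_like(latents)  # stand-in for the transformer's velocity prediction
  latents, state = scheduler.step(state, model_output, t, latents, return_dict=False)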