diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 728d2f2e3..05f332fb7 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -50,7 +50,7 @@ jobs: ruff check . - name: PyTest run: | - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 7415ed682..42e50d775 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -374,6 +374,7 @@ _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"] _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] + _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) _import_structure["schedulers"].extend( [ @@ -453,6 +454,7 @@ from .models.modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .models.ltx_video.transformers.transformer3d import Transformer3DModel from .models.vae_flax import FlaxAutoencoderKL from .pipelines import FlaxDiffusionPipeline from .schedulers import ( diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index dd78eaa6c..27046a0b6 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -213,8 +213,11 @@ def load_state_if_possible( max_logging.log(f"restoring from this run's directory latest step {latest_step}") try: if not enable_single_replica_ckpt_restoring: - item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} - return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) + if checkpoint_item == " ": + return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state)) + else: + item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} + return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) def map_to_pspec(data): pspec = data.sharding.spec @@ -248,3 +251,4 @@ def map_to_pspec(data): except: max_logging.log(f"could not load {checkpoint_item} from orbax") return None + diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml new file mode 100644 index 000000000..140fda56c --- /dev/null +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -0,0 +1,68 @@ +# Copyright 2025 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +#hardware +hardware: 'tpu' +skip_jax_distributed_system: False + +jax_cache_dir: '' +weights_dtype: 'bfloat16' +activations_dtype: 'bfloat16' + + +run_name: '' +output_dir: 'ltx-video-output' +save_config_to_gcs: False + +#parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] +logical_axis_rules: [ + ['batch', 'data'], + ['activation_heads', 'fsdp'], + ['activation_batch', ['data','fsdp']], + ['activation_kv', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['norm', 'fsdp'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ['conv_in', 'fsdp'] + ] +data_sharding: [['data', 'fsdp', 'tensor']] +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + + + + +learning_rate_schedule_steps: -1 +max_train_steps: 500 #TODO: change this +pretrained_model_name_or_path: '' +unet_checkpoint: '' +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +per_device_batch_size: 1 +compile_topology_num_slices: -1 +quantization_local_shard_count: -1 +jit_initializers: True +enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py new file mode 100644 index 000000000..2dec16fa6 --- /dev/null +++ b/src/maxdiffusion/generate_ltx_video.py @@ -0,0 +1,202 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +from absl import app +from typing import Sequence +import jax +import json +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import os +import functools +import jax.numpy as jnp +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import ( + create_device_mesh, + setup_initial_state, + get_memory_allocations, +) +from jax.sharding import Mesh, PartitionSpec as P +import orbax.checkpoint as ocp + + +def validate_transformer_inputs( + prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids +): + print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) + print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) + print("latents.shape: ", latents.shape, latents.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) + print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) + + +def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): + latents, state, noise_cond = args + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + return noise_pred, state, noise_cond + + +def run_inference( + states, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + segment_ids, + encoder_attention_segment_ids, +): + transformer_state = states["transformer"] + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + fractional_cords=fractional_cords, + prompt_embeds=prompt_embeds, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) + return noise_pred + + +def run(config): + key = jax.random.PRNGKey(42) + + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + base_dir = os.path.dirname(__file__) + + ##load in model config + config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + transformer_param_shapes = transformer.init_weights(in_channels, key, model_config["caption_channels"], eval_only=True) # noqa F841 + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True + ) + + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + states["transformer"] = transformer_state + + # create dummy inputs: + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "latents": (batch_size, num_tokens, in_channels), + "fractional_coords": (batch_size, 3, num_tokens), + "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(example_inputs["latents"], data_sharding) + prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + + validate_transformer_inputs( + prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids + ) + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + config=config, + mesh=mesh, + latents=latents, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ), + in_shardings=(state_shardings,), + out_shardings=None, + ) + + noise_pred = p_run_inference(states).block_until_ready() + print(noise_pred) # (4, 256, 128) + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f97..e645ecec1 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -402,7 +402,10 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 95861e24e..96a6f1286 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -13,9 +13,7 @@ # limitations under the License. from typing import TYPE_CHECKING - -from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available - +from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available _import_structure = {} @@ -32,6 +30,7 @@ from .vae_flax import FlaxAutoencoderKL from .lora import * from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .ltx_video.transformers.transformer3d import Transformer3DModel else: import sys diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py new file mode 100644 index 000000000..7e4185f36 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py new file mode 100644 index 000000000..ee7221652 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -0,0 +1,86 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from enum import Enum, auto +from typing import Optional + +import jax +from flax import linen as nn + +SKIP_GRADIENT_CHECKPOINT_KEY = "skip" + + +class GradientCheckpointType(Enum): + """ + Defines the type of the gradient checkpoint we will have + + NONE - means no gradient checkpoint + FULL - means full gradient checkpoint, wherever possible (minimum memory usage) + MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, + except for ones that involve batch dimension - that means that all attention and projection + layers will have gradient checkpoint, but not the backward with respect to the parameters + """ + + NONE = auto() + FULL = auto() + MATMUL_WITHOUT_BATCH = auto() + + @classmethod + def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": + """ + Constructs the gradient checkpoint type from a string + + Args: + s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. + + Returns: + GradientCheckpointType: The policy that corresponds to the string + """ + if s is None: + s = "none" + return GradientCheckpointType[s.upper()] + + def to_jax_policy(self): + """ + Converts the gradient checkpoint type to a jax policy + """ + match self: + case GradientCheckpointType.NONE: + return SKIP_GRADIENT_CHECKPOINT_KEY + case GradientCheckpointType.FULL: + return None + case GradientCheckpointType.MATMUL_WITHOUT_BATCH: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + def apply(self, module: nn.Module) -> nn.Module: + """ + Applies a gradient checkpoint policy to a module + if no policy is needed, it will return the module as is + + Args: + module (nn.Module): the module to apply the policy to + + Returns: + nn.Module: the module with the policy applied + """ + policy = self.to_jax_policy() + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module + return nn.remat( # pylint: disable=invalid-name + module, + prevent_cse=False, + policy=policy, + ) diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py new file mode 100644 index 000000000..3503ab3b4 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -0,0 +1,125 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import Union, Iterable, Tuple, Optional, Callable + +import numpy as np +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + + +Shape = Tuple[int, ...] +Initializer = Callable[[jax.random.PRNGKey, Shape, jax.numpy.dtype], jax.Array] +InitializerAxis = Union[int, Shape] + + +def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple(ax if ax >= 0 else ndim + ax for ax in axes) + + +def _canonicalize_tuple(x): + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +NdInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +KernelInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] + + +class DenseGeneral(nn.Module): + """A linear transformation with flexible axes. + + Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + use_bias: whether to add bias in linear transformation. + bias_norm: whether to add normalization before adding bias. + quant: quantization config, defaults to None implying no quantization. + """ + + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + weight_dtype: jnp.dtype = jnp.float32 + dtype: np.dtype = jnp.float32 + kernel_init: KernelInitializer = lecun_normal() + kernel_axes: Tuple[Optional[str], ...] = () + use_bias: bool = False + matmul_precision: str = "default" + + bias_init: Initializer = jax.nn.initializers.constant(0.0) + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. + """ + + def compute_dot_general(inputs, kernel, axis, contract_ind): + """Computes a dot_general operation that may be quantized.""" + dot_general = jax.lax.dot_general + matmul_precision = jax.lax.Precision(self.matmul_precision) + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) + + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + # kernel_in_axis = np.arange(len(axis)) + # kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) + kernel = self.param( + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + ) + kernel = jnp.asarray(kernel, self.dtype) + + contract_ind = tuple(range(0, len(axis))) + output = compute_dot_general(inputs, kernel, axis, contract_ind) + + if self.use_bias: + bias_axes, bias_shape = ( + self.kernel_axes[-len(features) :], + kernel_shape[-len(features) :], + ) + bias = self.param( + "bias", + nn.with_logical_partitioning(self.bias_init, bias_axes), + bias_shape, + self.weight_dtype, + ) + bias = jnp.asarray(bias, self.dtype) + + output += bias + return output diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py new file mode 100644 index 000000000..7e9cc80c4 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -0,0 +1,118 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from dataclasses import field +from typing import Any, Callable, Dict, List, Tuple, Optional + +import jax +from flax import linen as nn +from flax.linen import partitioning + + +class RepeatableCarryBlock(nn.Module): + """ + Integrates an input module in a jax carry format + + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ + + module: Callable[[Any], nn.Module] + module_init_args: List[Any] + module_init_kwargs: Dict[str, Any] + + @nn.compact + def __call__(self, *args) -> Tuple[jax.Array, None]: + """ + jax carry-op format of block + assumes the input contains an input tensor to the block along with kwargs that might be send to the block + kwargs are assumed to have static role, while the input changes between cycles + + Returns: + Tuple[jax.Array, None]: Output tensor from the block + """ + mod = self.module(*self.module_init_args, **self.module_init_kwargs) + output = mod(*args) + return output, None + + +class RepeatableLayer(nn.Module): + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ + + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ + + module: Callable[[Any], nn.Module] + """ + A Callable function for single block construction + """ + + num_layers: int + """ + The amount of blocks to build + """ + + module_init_args: List[Any] = field(default_factory=list) + """ + args passed to RepeatableLayer.module callable, to support block construction + """ + + module_init_kwargs: Dict[str, Any] = field(default_factory=dict) + """ + kwargs passed to RepeatableLayer.module callable, to support block construction + """ + + pspec_name: Optional[str] = None + """ + Partition spec metadata + """ + + param_scan_axis: int = 0 + """ + The axis that the "layers" will be aggragated on + eg: if a kernel is shaped (8, 16) + N layers will be (N, 8, 16) if param_scan_axis=0 + and (8, N, 16) if param_scan_axis=1 + """ + + @nn.compact + def __call__(self, *args): + + scan_kwargs = {} + if self.pspec_name is not None: + scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name} + + initializing = self.is_mutable_collection("params") + params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, # Separate params per timestep + split_rngs={"params": True}, + in_axes=(nn.broadcast,) * (len(args) - 1), + length=self.num_layers, + **scan_kwargs, + ) + wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) + x, _ = wrapped_function(*args) + return x diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py new file mode 100644 index 000000000..7e4185f36 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py new file mode 100644 index 000000000..4ae1d9a00 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -0,0 +1,189 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, KernelInitializer + + +ACTIVATION_FUNCTIONS = { + "swish": jax.nn.silu, + "silu": jax.nn.silu, + # Mish is not in JAX by default + "mish": lambda x: x * jax.nn.tanh(jax.nn.softplus(x)), + "gelu": jax.nn.gelu, + "relu": jax.nn.relu, +} + + +@jax.jit +def approximate_gelu(x: jax.Array) -> jax.Array: + """ + Computes Gaussian Error Linear Unit (GELU) activation function + + Args: + x (jax.Array): The input tensor + + jax.Array: The output tensor + """ + # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs + # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior + if x.dtype in (jax.numpy.float64,): + x = x.clip(-10, None) + return jax.nn.gelu(x, approximate=True) + + +def get_activation(act_fn: str): + """Returns the activation function from string.""" + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + approximate: str = "none" + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def gelu(self, gate: jax.Array) -> jax.Array: + approximate_to_tanh = self.approximate == "tanh" + if approximate_to_tanh: + return approximate_gelu(gate) + else: + return jax.nn.gelu(gate, approximate=False) + + @nn.compact + def __call__(self, hidden_states): + if self.approximate not in ("none", "tanh"): + raise ValueError(f"approximate must be 'none' or 'tanh', got {self.approximate}") + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + hidden_states = proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states, *args, **kwargs): + # if len(args) > 0 or kwargs.get("scale", None) is not None: + # deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + # deprecate("scale", "1.0.0", deprecation_message) + + proj = DenseGeneral( + features=self.dim_out * 2, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + + hidden_states = proj(hidden_states) + hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) + return hidden_states * jax.nn.gelu(gate, approximate=False) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, x): + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + x = proj(x) + return x * jax.nn.sigmoid(1.702 * x) diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py new file mode 100644 index 000000000..e9b287649 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -0,0 +1,209 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import Dict, Optional, Tuple + +import jax +import jax.nn +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.transformers.activations import get_activation +from maxdiffusion.models.ltx_video.linear import DenseGeneral + + +def get_timestep_embedding_multidim( + timesteps: jnp.ndarray, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> jnp.ndarray: + """ + Computes sinusoidal timestep embeddings while preserving the original dimensions. + No reshaping to 1D is performed at any stage. + + Args: + timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. + embedding_dim (int): The dimension of the output. + flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) + or `sin, cos` (if False). + downscale_freq_shift (float): Controls the delta between frequencies between dimensions. + scale (float): Scaling factor applied to the embeddings. + max_period (int): Controls the maximum frequency of the embeddings. + + Returns: + jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. + """ + half_dim = embedding_dim // 2 + exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) + emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape + emb = nn.with_logical_constraint(emb, ("activation_batch", "activation_norm_length", "activation_embed")) + # Broadcasting to match shape (*timesteps.shape, half_dim) + emb = timesteps[..., None] * emb + emb = scale * emb + # Shape (*timesteps.shape, embedding_dim) + emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) + if flip_sin_to_cos: + emb = jnp.concatenate([emb[..., half_dim:], emb[..., :half_dim]], axis=-1) + + return emb + + +class TimestepEmbedding(nn.Module): + in_channels: int + time_embed_dim: int + act_fn: str = "silu" + out_dim: Optional[int] = None + sample_proj_bias: bool = True + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers efficiently""" + self.linear_1 = DenseGeneral( + self.time_embed_dim, + use_bias=self.sample_proj_bias, + kernel_axes=(None, "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + ) + + self.act = get_activation(self.act_fn) + time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim + self.linear_2 = DenseGeneral( + time_embed_dim_out, + use_bias=self.sample_proj_bias, + kernel_axes=("embed", "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + ) + + def __call__(self, sample, condition=None): + sample = nn.with_logical_constraint(sample, ("activation_batch", "activation_norm_length", "activation_embed")) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + num_channels: int + flip_sin_to_cos: bool + downscale_freq_shift: float + scale: int = 1 + + def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: + t_emb = get_timestep_embedding_multidim( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class AlphaCombinedTimestepSizeEmbeddings(nn.Module): + """ """ + + embedding_dim: int + size_emb_dim: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize sub-modules.""" + self.outdim = self.size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.embedding_dim, + name="timestep_embedder", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype)) + return timesteps_emb + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + embedding_dim: int + embedding_coefficient: int = 6 + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + self.emb = AlphaCombinedTimestepSizeEmbeddings( + self.embedding_dim, + size_emb_dim=self.embedding_dim // 3, + name="emb", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + self.silu = jax.nn.silu + self.linear = DenseGeneral( + self.embedding_coefficient * self.embedding_dim, + use_bias=True, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear", + ) + + def __call__( + self, + timestep: jnp.ndarray, + added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[jnp.dtype] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Compute AdaLayerNorm-Single modulation. + + Returns: + Tuple: + - Processed embedding after SiLU + linear transformation. + - Original embedded timestep. + """ + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py new file mode 100644 index 000000000..9faab1ded --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -0,0 +1,915 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from functools import partial +import math +from typing import Any, Dict, Optional, Tuple +from enum import Enum, auto + +import jax +import jax.nn as jnn +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from jax.experimental.shard_map import shard_map +from jax.experimental.pallas.ops.tpu.flash_attention import ( + flash_attention as jax_flash_attention, + SegmentIds, + BlockSizes, +) + +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, Initializer +from maxdiffusion.models.ltx_video.transformers.activations import ( + GELU, + GEGLU, + ApproximateGELU, +) + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() + + +class Identity(nn.Module): + + def __call__(self, x): + return x + + +class BasicTransformerBlock(nn.Module): + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", "rms_norm"] + assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 2. Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 4. Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") + + batch_size = hidden_states.shape[0] + + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + # [batch, 1 or num_tokens, embedding_dim] + assert timestep.ndim == 3 + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = (jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2)) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. Self-Attention + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) + + attn_output = nn.with_logical_constraint(attn_output, ("activation_batch", "activation_norm_length", "activation_embed")) + + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint(attn_input, ("activation_batch", "activation_norm_length", "activation_embed")) + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint(ff_output, ("activation_batch", "activation_norm_length", "activation_embed")) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states + + +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." + ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) + + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="to_out.0", + matmul_precision=self.matmul_precision, + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any(device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print("Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa F821 + assert cross_attention_kwargs.get("scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) + + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key = self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") + + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) + + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) + + hidden_states_a = self.attention(query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype) + + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states + + +class AttentionOp(nn.Module): + + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. + # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) + qkvo_sharding_spec = jax.sharding.PartitionSpec( + "data", + "fsdp", + None, + "tensor", + ) + # Based on: ("activation_kv_batch", "activation_length") + qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None) + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. + + TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM + we want to utilize the SRAM the best we can + + too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data + from the slower HBRAM + + a certain balance has to be met to get the best performance + imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) + along with the SRAM cache size + + ** SRAM cache size for TPU + V5P - 1MB SRAM per core + + Args: + q (jax.Array): Query tensor to be used + k (jax.Array): Key tensor to be used + + Returns: + BlockSizes: Grid block sizes + """ + max_block_size = 1024 if dtype == jnp.bfloat16 else 512 + return BlockSizes( + block_q=min(max_block_size, q.shape[-2]), + block_k_major=min(max_block_size, k.shape[-2]), + block_k=min(max_block_size, k.shape[-2]), + block_b=min(1, q.shape[0]), + block_q_major_dkv=min(max_block_size, q.shape[-2]), + block_k_major_dkv=min(max_block_size, k.shape[-2]), + block_q_dkv=min(max_block_size, q.shape[-2]), + block_k_dkv=min(max_block_size, k.shape[-2]), + block_q_dq=min(max_block_size, q.shape[-2]), + block_k_dq=min(512, k.shape[-2]), + block_k_major_dq=min(max_block_size, k.shape[-2]), + ) + + +class ExplicitAttention(nn.Module): + + def __call__( + self, + q: jax.Array, + k: jax.Array, + v: jax.Array, + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + ): + assert sharding_mesh is None, "Explicit attention does not support sharding mesh." + attn_mask = None + if kv_segment_ids is not None: + q_segment_ids_expanded = q_segment_ids[:, None, :, None] + kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] + attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded + + scale_factor = 1 / jnp.sqrt(q.shape[-1]) + attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == jnp.bool_: + attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = q @ k.swapaxes(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = jnn.softmax(attn_weight, axis=-1) + + return attn_weight @ v + + +class RMSNorm(nn.Module): + """ + RMSNorm is a normalization layer that normalizes the input using the root mean square. + """ + + epsilon: float + dtype: jnp.dtype = jnp.float32 + elementwise_affine: bool = True + weight_dtype: jnp.dtype = jnp.float32 + kernel_axes: Tuple[Optional[str], ...] = () + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, hidden_states: jax.Array) -> jax.Array: + """ + Forward pass of the RMSNorm layer. + + First we compute the variance (mean of the square of the input) + and then normalize the input using the root mean square. + + NOTE: if weight is in mixed precision, the operand should be in the same precision. + Args: + hidden_states (jax.Array): Input data + + Returns: + jax.Array: Normed data + """ + + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) + + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + # round to nearest multiple of 256 + inner_dim = round(inner_dim / 256) * 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError(f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + return hidden_states + + +def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: + """ + Integrates positional information into input tensors using RoPE. + + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") + + cos_freqs, sin_freqs = freqs_cis + + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) + + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py new file mode 100644 index 000000000..d8240989c --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -0,0 +1,56 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from flax import linen as nn +import jax.numpy as jnp + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.activations import approximate_gelu + + +class CaptionProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + """ + + in_features: int + hidden_size: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, caption): + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + )(caption) + hidden_states = approximate_gelu(hidden_states) + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + )(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py new file mode 100644 index 000000000..d6c7cf4c4 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -0,0 +1,305 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +from typing import List, Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.adaln import AdaLayerNormSingle +from maxdiffusion.models.ltx_video.transformers.attention import BasicTransformerBlock +from maxdiffusion.models.ltx_video.transformers.caption_projection import CaptionProjection +from maxdiffusion.models.ltx_video.gradient_checkpoint import GradientCheckpointType +from maxdiffusion.models.ltx_video.repeatable_layer import RepeatableLayer + + +class Transformer3DModel(nn.Module): + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( # noqa C408 + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def init_weights(self, in_channels, key, caption_channels, eval_only=True): + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + if eval_only: + return jax.eval_shape( + self.init, + key, + **example_inputs, + )["params"] + else: + return self.init(key, **example_inputs)["params"] + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + + if self.num_layers > 0: + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + ) + # Output processing + + scale_shift_values = self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] + scale_shift_values = nn.with_logical_constraint( + scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states + + +def log_base(x: jax.Array, base: jax.Array) -> jax.Array: + """ + Computes log of x with defined base. + + Args: + x (jax.Array): log value + base (jax.Array): base of the log + + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) + + +class FreqsCisPrecomputer(nn.Module): + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + dtype = jnp.float32 # We need full precision in the freqs_cis computation. + dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json new file mode 100644 index 000000000..c5b3c0ef9 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -0,0 +1,26 @@ +{ + "ckpt_path": "", + "activation_fn": "gelu-approximate", + "attention_bias": true, + "attention_head_dim": 128, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 4096, + "double_self_attention": false, + "dropout": 0.0, + "norm_elementwise_affine": false, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 48, + "only_cross_attention": false, + "out_channels": 128, + "upcast_attention": false, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, + "in_channels": 128 +} \ No newline at end of file diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py new file mode 100644 index 000000000..9398c9156 --- /dev/null +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -0,0 +1,202 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import os +import torch +import jax +import numpy as np +import jax.numpy as jnp +import unittest +from absl.testing import absltest +from jax.sharding import Mesh +import json +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import functools +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import ( + create_device_mesh, + setup_initial_state, + get_memory_allocations, +) +from jax.sharding import PartitionSpec as P +import orbax.checkpoint as ocp + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def load_ref_prediction(): + base_dir = os.path.dirname(__file__) + saved_prediction_path = os.path.join(base_dir, "ltx_vid_transformer_test_ref_pred") + predict_dict = torch.load(saved_prediction_path) + noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) + return noise_pred_pt + + +def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): + latents, state, noise_cond = args + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + return noise_pred, state, noise_cond + + +def run_inference( + states, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + segment_ids, + encoder_attention_segment_ids, +): + transformer_state = states["transformer"] + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + fractional_cords=fractional_cords, + prompt_embeds=prompt_embeds, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) + return latents + + +class LTXTransformerTest(unittest.TestCase): + + def test_one_step_transformer(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"), + ], + unittest=True, + ) + config = pyconfig.config + noise_pred_pt = load_ref_prediction() + + # set up transformer + key = jax.random.PRNGKey(42) + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + base_dir = os.path.dirname(__file__) + config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") + + with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True + ) + + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + states["transformer"] = transformer_state + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "latents": (batch_size, num_tokens, in_channels), + "fractional_coords": (batch_size, 3, num_tokens), + "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(example_inputs["latents"], data_sharding) + prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + config=config, + mesh=mesh, + latents=latents, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ), + in_shardings=(state_shardings,), + out_shardings=None, + ) + + noise_pred = p_run_inference(states).block_until_ready() + noise_pred = torch.from_numpy(np.array(noise_pred)) + + torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred new file mode 100644 index 000000000..0a9fe9120 Binary files /dev/null and b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred differ