From 6552f146ce64d7608da8b7b672e29db036f9d124 Mon Sep 17 00:00:00 2001 From: Serena Date: Mon, 9 Jun 2025 17:19:35 +0000 Subject: [PATCH 1/4] model setup --- setup.sh | 2 +- src/maxdiffusion/__init__.py | 2 + src/maxdiffusion/configs/ltx_video.yml | 50 + src/maxdiffusion/generate_ltx_video.py | 73 ++ src/maxdiffusion/models/__init__.py | 4 +- src/maxdiffusion/models/ltx_video/__init__.py | 0 .../models/ltx_video/gradient_checkpoint.py | 70 ++ src/maxdiffusion/models/ltx_video/linear.py | 109 +++ src/maxdiffusion/models/ltx_video/main.py | 40 + .../models/ltx_video/repeatable_layer.py | 102 ++ .../models/ltx_video/transformers/__init__.py | 0 .../ltx_video/transformers/activations.py | 173 ++++ .../models/ltx_video/transformers/adaln.py | 194 ++++ .../ltx_video/transformers/attention.py | 899 ++++++++++++++++++ .../transformers/caption_projection.py | 40 + .../ltx_video/transformers/transformer3d.py | 311 ++++++ .../ltx_video/xora_v1.2-13B-balanced-128.json | 24 + 17 files changed, 2090 insertions(+), 3 deletions(-) create mode 100644 src/maxdiffusion/configs/ltx_video.yml create mode 100644 src/maxdiffusion/generate_ltx_video.py create mode 100644 src/maxdiffusion/models/ltx_video/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/gradient_checkpoint.py create mode 100644 src/maxdiffusion/models/ltx_video/linear.py create mode 100644 src/maxdiffusion/models/ltx_video/main.py create mode 100644 src/maxdiffusion/models/ltx_video/repeatable_layer.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/activations.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/adaln.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/attention.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/caption_projection.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/transformer3d.py create mode 100644 
src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json diff --git a/setup.sh b/setup.sh index fc3f640e0..9beb11d23 100644 --- a/setup.sh +++ b/setup.sh @@ -110,4 +110,4 @@ else fi # Install maxdiffusion -pip3 install -U . || echo "Failed to install maxdiffusion" >&2 +pip3 install -e . || echo "Failed to install maxdiffusion" >&2 diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 7415ed682..1cfd62930 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -373,6 +373,7 @@ _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"] + _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"] _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) _import_structure["schedulers"].extend( @@ -453,6 +454,7 @@ from .models.modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .models.ltx_video.transformers.transformer3d import Transformer3DModel from .models.vae_flax import FlaxAutoencoderKL from .pipelines import FlaxDiffusionPipeline from .schedulers import ( diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml new file mode 100644 index 000000000..798cd9f83 --- /dev/null +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -0,0 +1,50 @@ +#hardware +hardware: 'tpu' +skip_jax_distributed_system: False + +jax_cache_dir: '' +weights_dtype: 'bfloat16' +activations_dtype: 'bfloat16' + + +run_name: '' +output_dir: 'ltx-video-output' +save_config_to_gcs: False + +#parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] +logical_axis_rules: [ + ['batch', 
def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond):
    """Print shape and dtype of each Transformer3DModel input (debug helper)."""
    print("prompt_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
    print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
    print("latents.shape: ", latents.shape, latents.dtype)
    print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)


def run(config):
    """Build the LTX-Video Transformer3DModel and materialize its sharded state.

    Args:
        config: pyconfig config object providing mesh axes and sharding options.

    Returns:
        Tuple of (transformer_state, transformer_state_shardings) produced by
        setup_initial_state, so callers can reuse the initialized state.
    """
    key = jax.random.PRNGKey(0)

    devices_array = create_device_mesh(config)
    mesh = Mesh(devices_array, config.mesh_axes)

    batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128
    base_dir = os.path.dirname(__file__)

    # Load the model architecture definition shipped next to the model code.
    config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
    with open(config_path, "r") as f:
        model_config = json.load(f)

    transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch")

    # FIX: the previous version also called transformer.init_weights(...) eagerly
    # here and discarded the result; setup_initial_state initializes the params
    # through weights_init_fn below, so the extra call only wasted time/memory.
    key, split_key = jax.random.split(key)
    weights_init_fn = functools.partial(
        transformer.init_weights,
        split_key,
        batch_size,
        text_tokens,
        num_tokens,
        features,
        eval_only=False,
    )

    transformer_state, transformer_state_shardings = setup_initial_state(
        model=transformer,
        tx=None,  # no optimizer: inference-only state
        config=config,
        mesh=mesh,
        weights_init_fn=weights_init_fn,
        model_params=None,
        training=False,
    )
    return transformer_state, transformer_state_shardings


def main(argv: Sequence[str]) -> None:
    pyconfig.initialize(argv)
    run(pyconfig.config)


if __name__ == "__main__":
    app.run(main)
SKIP_GRADIENT_CHECKPOINT_KEY = "skip"


class GradientCheckpointType(Enum):
    """
    Defines the type of gradient checkpointing to use.

    NONE - no gradient checkpointing
    FULL - full gradient checkpointing wherever possible (minimum memory usage)
    MATMUL_WITHOUT_BATCH - checkpoint every linear/matmul operation except ones
    that involve the batch dimension: all attention and projection layers are
    checkpointed, but not the backward with respect to the parameters
    """

    NONE = auto()
    FULL = auto()
    MATMUL_WITHOUT_BATCH = auto()

    @classmethod
    def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
        """Resolve a policy by case-insensitive name; None selects NONE.

        Args:
            s (Optional[str], optional): The name of the gradient checkpointing
                policy. Defaults to None.

        Returns:
            GradientCheckpointType: The policy that corresponds to the string.
        """
        name = s if s is not None else "none"
        return cls[name.upper()]

    def to_jax_policy(self):
        """Translate this member into the policy value understood by nn.remat."""
        if self is GradientCheckpointType.NONE:
            return SKIP_GRADIENT_CHECKPOINT_KEY
        if self is GradientCheckpointType.FULL:
            # `None` means "no saveable policy": recompute everything.
            return None
        return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

    def apply(self, module: nn.Module) -> nn.Module:
        """Wrap `module` with nn.remat for this policy; NONE returns it as-is.

        Args:
            module (nn.Module): the module to apply the policy to

        Returns:
            nn.Module: the module with the policy applied
        """
        policy = self.to_jax_policy()
        if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
            return module
        return nn.remat(  # pylint: disable=invalid-name
            module,
            prevent_cse=False,
            policy=policy,
        )
def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    """Map possibly-negative axis indices to their non-negative equivalents."""
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    """Wrap a scalar into a 1-tuple; pass iterables through as tuples."""
    if isinstance(x, Iterable):
        return tuple(x)
    else:
        return (x,)


NdInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array]
KernelInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array]


class DenseGeneral(nn.Module):
    """A linear transformation with flexible axes.

    Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86

    Attributes:
      features: tuple with numbers of output features.
      axis: tuple with axes to apply the transformation on.
      weight_dtype: the dtype of the weights (default: float32).
      dtype: the dtype of the computation (default: float32).
      kernel_init: initializer function for the weight matrix.
      kernel_axes: logical partition names for the kernel dimensions.
      use_bias: whether to add bias in linear transformation.
      matmul_precision: jax.lax.Precision name used for the dot_general.
    """

    features: Union[Iterable[int], int]
    axis: Union[Iterable[int], int] = -1
    weight_dtype: jnp.dtype = jnp.float32
    dtype: jnp.dtype = jnp.float32
    kernel_init: KernelInitializer = lecun_normal()
    kernel_axes: Tuple[Optional[str], ...] = ()
    use_bias: bool = False
    matmul_precision: str = "default"

    bias_init: Initializer = jax.nn.initializers.constant(0.0)

    @nn.compact
    def __call__(self, inputs: jax.Array) -> jax.Array:
        """Applies a linear transformation to the inputs along multiple dimensions.

        Args:
            inputs: The nd-array to be transformed.

        Returns:
            The transformed input.
        """

        def compute_dot_general(inputs, kernel, axis, contract_ind):
            """Computes a dot_general operation that may be quantized."""
            dot_general = jax.lax.dot_general
            matmul_precision = jax.lax.Precision(self.matmul_precision)
            return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision)

        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        # Kernel layout: contracted input dims first, then the output feature dims.
        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        # NOTE(review): kernel_in_axis / kernel_out_axis are computed but unused below.
        kernel_in_axis = np.arange(len(axis))
        kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.weight_dtype,
        )
        # Params are stored in weight_dtype; the matmul runs in the compute dtype.
        kernel = jnp.asarray(kernel, self.dtype)

        contract_ind = tuple(range(0, len(axis)))
        output = compute_dot_general(inputs, kernel, axis, contract_ind)

        if self.use_bias:
            # Bias spans only the output-feature dims; reuse their kernel axis names.
            bias_axes, bias_shape = (
                self.kernel_axes[-len(features) :],
                kernel_shape[-len(features) :],
            )
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, bias_axes),
                bias_shape,
                self.weight_dtype,
            )
            bias = jnp.asarray(bias, self.dtype)

            output += bias
        return output
class RepeatableCarryBlock(nn.Module):
    """
    Integrates an input module in a jax carry format

    ergo, the module assumes the role of a building block
    and returns both input and output across all blocks
    """

    module: Callable[[Any], nn.Module]      # factory that constructs the wrapped block
    module_init_args: List[Any]             # positional args forwarded to the factory
    module_init_kwargs: Dict[str, Any]      # keyword args forwarded to the factory

    @nn.compact
    def __call__(self, *args) -> Tuple[jax.Array, None]:
        """
        jax carry-op format of block
        assumes the input contains an input tensor to the block along with kwargs that might be sent to the block
        kwargs are assumed to have static role, while the input changes between cycles

        Returns:
            Tuple[jax.Array, None]: Output tensor from the block (as the carry), and
            None as the per-step scan output.
        """
        mod = self.module(*self.module_init_args, **self.module_init_kwargs)
        output = mod(*args)
        return output, None


class RepeatableLayer(nn.Module):
    """
    RepeatableLayer will assume a similar role to torch.nn.ModuleList
    with the condition that each block has the same graph, and only the parameters differ

    The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation
    """

    module: Callable[[Any], nn.Module]
    """
    A Callable function for single block construction
    """

    num_layers: int
    """
    The amount of blocks to build
    """

    module_init_args: List[Any] = field(default_factory=list)
    """
    args passed to RepeatableLayer.module callable, to support block construction
    """

    module_init_kwargs: Dict[str, Any] = field(default_factory=dict)
    """
    kwargs passed to RepeatableLayer.module callable, to support block construction
    """

    pspec_name: Optional[str] = None
    """
    Partition spec metadata
    """

    param_scan_axis: int = 0
    """
    The axis that the "layers" will be aggregated on
    eg: if a kernel is shaped (8, 16)
    N layers will be (N, 8, 16) if param_scan_axis=0
    and (8, N, 16) if param_scan_axis=1
    """

    @nn.compact
    def __call__(self, *args):

        scan_kwargs = {}
        if self.pspec_name is not None:
            scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name}

        # During init, params are created with the stacked layer axis in place;
        # afterwards, ScanIn tells nn.scan the axis already exists in the pytree.
        initializing = self.is_mutable_collection("params")
        params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis)
        scan_fn = nn.scan(
            RepeatableCarryBlock,
            variable_axes={
                "params": params_spec,
                "cache": 0,
                "intermediates": 0,
                "aqt": 0,
                "_overwrite_with_gradient": 0,
            },  # Separate params per timestep
            # Each layer gets its own init rng so parameters differ across layers.
            split_rngs={"params": True},
            # First positional arg is the carry; all remaining args are broadcast
            # unchanged to every layer.
            in_axes=(nn.broadcast,) * (len(args) - 1),
            length=self.num_layers,
            **scan_kwargs,
        )
        wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs)
        x, _ = wrapped_function(*args)
        return x
ACTIVATION_FUNCTIONS = {
    "swish": jax.nn.silu,
    "silu": jax.nn.silu,
    "mish": lambda x: x * jax.nn.tanh(jax.nn.softplus(x)),  # Mish is not in JAX by default
    "gelu": jax.nn.gelu,
    "relu": jax.nn.relu,
}


@jax.jit
def approximate_gelu(x: jax.Array) -> jax.Array:
    """Tanh-approximated GELU with a guard against NaNs on float64 inputs.

    Args:
        x (jax.Array): The input tensor

    Returns:
        jax.Array: The output tensor
    """
    # erf saturates at -1 for very large negative inputs; on TPUs this has been
    # observed to yield NaNs, so clamp the lower tail for float64 inputs first.
    clipped = x.clip(-10, None) if x.dtype == jnp.float64 else x
    return jax.nn.gelu(clipped, approximate=True)


def get_activation(act_fn: str):
    """Return the activation callable registered under `act_fn` (case-insensitive)."""
    act_fn = act_fn.lower()
    if act_fn not in ACTIVATION_FUNCTIONS:
        raise ValueError(f"Unsupported activation function: {act_fn}")
    return ACTIVATION_FUNCTIONS[act_fn]
class GELU(nn.Module):
    r"""
    GELU activation function with tanh approximation support with `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    # NOTE(review): dim_in is never read below (DenseGeneral infers the input
    # width); it is kept for interface parity with the torch implementation.
    dim_in: int
    dim_out: int
    approximate: str = "none"
    bias: bool = True

    # Logical partition names / initializer for the projection kernel.
    kernel_axes: Tuple[Optional[str], ...] = ()
    kernel_init: KernelInitializer = lecun_normal()

    dtype: jnp.dtype = jnp.float32         # compute dtype
    weight_dtype: jnp.dtype = jnp.float32  # parameter storage dtype
    matmul_precision: str = "default"

    def gelu(self, gate: jax.Array) -> jax.Array:
        """Apply exact GELU, or the tanh approximation when configured."""
        approximate_to_tanh = self.approximate == "tanh"
        if approximate_to_tanh:
            return approximate_gelu(gate)
        else:
            return jax.nn.gelu(gate, approximate=False)

    @nn.compact
    def __call__(self, hidden_states):
        # Validated at call time; flax dataclass fields are not checked on construction.
        if self.approximate not in ("none", "tanh"):
            raise ValueError(f"approximate must be 'none' or 'tanh', got {self.approximate}")
        proj = DenseGeneral(
            features=self.dim_out,
            use_bias=self.bias,
            kernel_axes=self.kernel_axes,
            kernel_init=self.kernel_init,
            matmul_precision=self.matmul_precision,
            weight_dtype=self.weight_dtype,
            dtype=self.dtype,
            name="proj",
        )
        hidden_states = proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class GEGLU(nn.Module):
    r"""
    A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.

    Projects to `2 * dim_out`, splits the result into value and gate halves, and
    returns `value * gelu(gate)`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    dim_in: int
    dim_out: int
    bias: bool = True

    kernel_axes: Tuple[Optional[str], ...] = ()
    kernel_init: KernelInitializer = lecun_normal()

    dtype: jnp.dtype = jnp.float32
    weight_dtype: jnp.dtype = jnp.float32
    matmul_precision: str = "default"

    @nn.compact
    def __call__(self, hidden_states, *args, **kwargs):
        # `scale` used to be accepted here; warn callers that still pass it.
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        # Single fused projection producing both the value and the gate halves.
        proj = DenseGeneral(
            features=self.dim_out * 2,
            use_bias=self.bias,
            kernel_axes=self.kernel_axes,
            kernel_init=self.kernel_init,
            matmul_precision=self.matmul_precision,
            weight_dtype=self.weight_dtype,
            dtype=self.dtype,
            name="proj",
        )

        hidden_states = proj(hidden_states)
        hidden_states, gate = jnp.split(hidden_states, 2, axis=-1)
        return hidden_states * jax.nn.gelu(gate, approximate=False)


class ApproximateGELU(nn.Module):
    r"""
    The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
    [paper](https://arxiv.org/abs/1606.08415).

    Applies a linear projection followed by the sigmoid approximation
    `x * sigmoid(1.702 * x)`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    dim_in: int
    dim_out: int
    bias: bool = True

    kernel_axes: Tuple[Optional[str], ...] = ()
    kernel_init: KernelInitializer = lecun_normal()

    dtype: jnp.dtype = jnp.float32
    weight_dtype: jnp.dtype = jnp.float32
    matmul_precision: str = "default"

    @nn.compact
    def __call__(self, x):
        proj = DenseGeneral(
            features=self.dim_out,
            use_bias=self.bias,
            kernel_axes=self.kernel_axes,
            kernel_init=self.kernel_init,
            matmul_precision=self.matmul_precision,
            weight_dtype=self.weight_dtype,
            dtype=self.dtype,
            name="proj",
        )
        x = proj(x)
        # Sigmoid approximation of GELU (Hendrycks & Gimpel, section 2).
        return x * jax.nn.sigmoid(1.702 * x)
+ """ + half_dim = embedding_dim // 2 + exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) + emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape + emb = nn.with_logical_constraint(emb, ("activation_batch", "activation_norm_length", "activation_embed")) + emb = timesteps[..., None] * emb # Broadcasting to match shape (*timesteps.shape, half_dim) + emb = scale * emb + emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) # Shape (*timesteps.shape, embedding_dim) + if flip_sin_to_cos: + emb = jnp.concatenate([emb[..., half_dim:], emb[..., :half_dim]], axis=-1) + + return emb + + +class TimestepEmbedding(nn.Module): + in_channels: int + time_embed_dim: int + act_fn: str = "silu" + out_dim: Optional[int] = None + sample_proj_bias: bool = True + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + + def setup(self): + """Initialize layers efficiently""" + self.linear_1 = DenseGeneral( + self.time_embed_dim, + use_bias=self.sample_proj_bias, + kernel_axes=(None, "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + ) + + self.act = get_activation(self.act_fn) + time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim + self.linear_2 = DenseGeneral( + time_embed_dim_out, + use_bias=self.sample_proj_bias, + kernel_axes=("embed", "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + ) + + def __call__(self, sample, condition=None): + sample = nn.with_logical_constraint(sample, ("activation_batch", "activation_norm_length", "activation_embed")) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + + +class 
Timesteps(nn.Module): + num_channels: int + flip_sin_to_cos: bool + downscale_freq_shift: float + scale: int = 1 + + def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: + t_emb = get_timestep_embedding_multidim( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class AlphaCombinedTimestepSizeEmbeddings(nn.Module): + """ + + """ + + embedding_dim: int + size_emb_dim: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize sub-modules.""" + self.outdim = self.size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.embedding_dim, + name="timestep_embedder", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype)) + return timesteps_emb + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
+ """ + + embedding_dim: int + embedding_coefficient: int = 6 + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + self.emb = AlphaCombinedTimestepSizeEmbeddings( + self.embedding_dim, + size_emb_dim=self.embedding_dim // 3, + name="emb", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + self.silu = jax.nn.silu + self.linear = DenseGeneral( + self.embedding_coefficient * self.embedding_dim, + use_bias=True, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear", + ) + + def __call__( + self, + timestep: jnp.ndarray, + added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[jnp.dtype] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Compute AdaLayerNorm-Single modulation. + + Returns: + Tuple: + - Processed embedding after SiLU + linear transformation. + - Original embedded timestep. 
+ """ + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py new file mode 100644 index 000000000..e4c3351ee --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -0,0 +1,899 @@ +from functools import partial +import math +from typing import Any, Dict, Optional, Tuple +from enum import Enum, auto + +import jax +import jax.nn as jnn +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from jax.experimental.shard_map import shard_map +from jax.experimental.pallas.ops.tpu.flash_attention import ( + flash_attention as jax_flash_attention, + SegmentIds, + BlockSizes, +) + +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, Initializer +from maxdiffusion.models.ltx_video.transformers.activations import ( + GELU, + GEGLU, + ApproximateGELU, +) + + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() + + +class Identity(nn.Module): + def __call__(self, x): + return x + + +class BasicTransformerBlock(nn.Module): + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: 
class BasicTransformerBlock(nn.Module):
    """One transformer block: self-attention -> optional cross-attention -> feed-forward.

    Supports adaptive modulation ("single_scale_shift" / "single_scale"): a
    learned ``scale_shift_table`` is added to the per-sample timestep embedding
    to produce shift/scale/gate vectors applied around attention and the FFN.
    """

    dim: int
    num_attention_heads: int
    attention_head_dim: int
    dropout: float = 0.0
    cross_attention_dim: Optional[int] = None
    activation_fn: str = "geglu"
    num_embeds_ada_norm: Optional[int] = None
    attention_bias: bool = False
    only_cross_attention: bool = False
    double_self_attention: bool = False
    upcast_attention: bool = False
    norm_elementwise_affine: bool = True
    adaptive_norm: str = "single_scale_shift"
    standardization_norm: str = "layer_norm"
    norm_eps: float = 1e-5
    # BUG FIX: was ``qk_norm: str = None`` — None is not a valid ``str``.
    qk_norm: Optional[str] = None
    final_dropout: bool = False
    # BUG FIX: a stray trailing comma previously made this default the 1-tuple
    # ("default",) although the field is annotated ``str``.
    attention_type: str = "default"  # pylint: disable=unused-argument
    ff_inner_dim: Optional[int] = None
    ff_bias: bool = True
    attention_out_bias: bool = True
    use_tpu_flash_attention: bool = True
    use_rope: bool = False
    ffn_dim_mult: Optional[int] = 4
    attention_op: Optional[nn.Module] = None
    sharding_mesh: Optional[jax.sharding.Mesh] = None

    dtype: jnp.dtype = jnp.float32
    weight_dtype: jnp.dtype = jnp.float32
    matmul_precision: str = "default"

    def setup(self):
        assert self.standardization_norm in ["layer_norm", "rms_norm"]
        assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"]
        assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention."

        if self.standardization_norm == "layer_norm":
            make_norm_layer = partial(
                nn.LayerNorm,
                epsilon=self.norm_eps,
                param_dtype=self.weight_dtype,
                dtype=self.dtype,
            )
        else:
            make_norm_layer = partial(
                RMSNorm,
                epsilon=self.norm_eps,
                elementwise_affine=self.norm_elementwise_affine,
                weight_dtype=self.weight_dtype,
                dtype=self.dtype,
                kernel_axes=("norm",),
            )

        # 1. Self-attention.
        self.norm1 = make_norm_layer(name="norm1")
        self.attn1 = Attention(
            query_dim=self.dim,
            heads=self.num_attention_heads,
            dim_head=self.attention_head_dim,
            dropout=self.dropout,
            bias=self.attention_bias,
            cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None,
            upcast_attention=self.upcast_attention,
            out_bias=self.attention_out_bias,
            use_tpu_flash_attention=self.use_tpu_flash_attention,
            qk_norm=self.qk_norm,
            use_rope=self.use_rope,
            attention_op=self.attention_op,
            name="attn1",
            dtype=self.dtype,
            weight_dtype=self.weight_dtype,
            matmul_precision=self.matmul_precision,
        )

        # 2. Cross-attention (or a second self-attention when double_self_attention).
        if self.cross_attention_dim is not None or self.double_self_attention:
            self.attn2 = Attention(
                query_dim=self.dim,
                cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None,
                heads=self.num_attention_heads,
                dim_head=self.attention_head_dim,
                dropout=self.dropout,
                bias=self.attention_bias,
                upcast_attention=self.upcast_attention,
                out_bias=self.attention_out_bias,
                use_tpu_flash_attention=self.use_tpu_flash_attention,
                qk_norm=self.qk_norm,
                use_rope=self.use_rope,
                attention_op=self.attention_op,
                name="attn2",
                dtype=self.dtype,
                weight_dtype=self.weight_dtype,
                # CONSISTENCY FIX: attn1 forwards matmul_precision, but attn2
                # previously fell back silently to the Attention default.
                matmul_precision=self.matmul_precision,
            )
            # Only the non-adaptive path applies an explicit norm before attn2.
            self.attn2_norm = make_norm_layer() if self.adaptive_norm == "none" else None
        else:
            self.attn2 = None
            self.attn2_norm = None

        self.norm2 = make_norm_layer(name="norm2")

        # 3. Feed-forward. NOTE: the positional ``self.dim`` binds FeedForward.dim_out,
        # keeping the block's output width equal to its input width.
        self.ff = FeedForward(
            self.dim,
            dropout=self.dropout,
            activation_fn=self.activation_fn,
            final_dropout=self.final_dropout,
            inner_dim=self.ff_inner_dim,
            bias=self.ff_bias,
            mult=self.ffn_dim_mult,
            name="ff",
            dtype=self.dtype,
            weight_dtype=self.weight_dtype,
            matmul_precision=self.matmul_precision,
        )

        # 4. Scale-shift modulation table.
        if self.adaptive_norm != "none":
            num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6

            def ada_initializer(key):
                return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5

            self.scale_shift_table = self.param(
                "scale_shift_table",  # Trainable parameter name
                nn.with_logical_partitioning(ada_initializer, ("ada", "embed")),
            )

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
        segment_ids: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_segment_ids: Optional[jnp.ndarray] = None,
        timestep: Optional[jnp.ndarray] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[jnp.ndarray] = None,
        skip_layer_mask: Optional[jnp.ndarray] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
    ) -> jnp.ndarray:
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                # TYPO FIX in user-facing message ("depcrecated" -> "deprecated").
                print("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        hidden_states = nn.with_logical_constraint(
            hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
        )
        hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states")

        batch_size = hidden_states.shape[0]

        # 0. Pre-attention norm.
        norm_hidden_states = self.norm1(hidden_states)

        norm_hidden_states = nn.with_logical_constraint(
            norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
        )

        # Adaptive modulation: learned table + per-sample timestep embedding.
        if self.adaptive_norm in ["single_scale_shift", "single_scale"]:
            assert timestep.ndim == 3  # [batch, 1 or num_tokens, embedding_dim]
            num_ada_params = self.scale_shift_table.shape[0]
            ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape(
                batch_size, timestep.shape[1], num_ada_params, -1
            )
            # Moving ada values to computation dtype to prevent dtype promotion
            ada_values = ada_values.astype(self.dtype)
            ada_values = nn.with_logical_constraint(
                ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed")
            )

            if self.adaptive_norm == "single_scale_shift":
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                    jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2)
                )
                norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            else:
                scale_msa, gate_msa, scale_mlp, gate_mlp = (
                    jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2)
                )
                norm_hidden_states = norm_hidden_states * (1 + scale_msa)
        elif self.adaptive_norm == "none":
            scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None
        else:
            raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")

        if norm_hidden_states.shape[1] == 1:
            norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1)

        # 1. Self-attention.
        attn_output = self.attn1(
            norm_hidden_states,
            freqs_cis=freqs_cis,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            segment_ids=segment_ids,
            kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids,
            sharding_mesh=self.sharding_mesh,
            skip_layer_mask=skip_layer_mask,
            skip_layer_strategy=skip_layer_strategy,
            **(cross_attention_kwargs or {}),
        )

        attn_output = nn.with_logical_constraint(
            attn_output, ("activation_batch", "activation_norm_length", "activation_embed")
        )

        if gate_msa is not None:
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = jnp.squeeze(hidden_states, axis=1)

        # 2. Cross-attention.
        if self.attn2 is not None:
            attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states
            attn_input = nn.with_logical_constraint(
                attn_input, ("activation_batch", "activation_norm_length", "activation_embed")
            )
            attn_output = self.attn2(
                attn_input,
                freqs_cis=freqs_cis,
                encoder_hidden_states=encoder_hidden_states,
                segment_ids=segment_ids,
                kv_attention_segment_ids=encoder_attention_segment_ids,
                sharding_mesh=self.sharding_mesh,
                **(cross_attention_kwargs or {}),
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward with its own modulation.
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = nn.with_logical_constraint(
            norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
        )

        if self.adaptive_norm == "single_scale_shift":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        elif self.adaptive_norm == "single_scale":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp)
        elif self.adaptive_norm == "none":
            pass
        else:
            raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")

        ff_output = self.ff(norm_hidden_states)
        ff_output = nn.with_logical_constraint(
            ff_output, ("activation_batch", "activation_norm_length", "activation_embed")
        )
        if gate_mlp is not None:
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = jnp.squeeze(hidden_states, axis=1)
        hidden_states = nn.with_logical_constraint(
            hidden_states,
            ("activation_batch", "activation_norm_length", "activation_embed"),
        )
        return hidden_states
class Attention(nn.Module):
    """Multi-head attention (self or cross) backed by TPU flash attention.

    Projects inputs through to_q/to_k/to_v DenseGeneral layers, optionally
    applies RoPE and q/k normalization, runs the configured attention op,
    then reprojects through ``to_out``.
    """

    query_dim: int
    cross_attention_dim: Optional[int] = None
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    bias: bool = False
    upcast_attention: bool = False
    upcast_softmax: bool = False
    cross_attention_norm: Optional[str] = None
    added_kv_proj_dim: Optional[int] = None
    out_bias: bool = True
    scale_qk: bool = True
    qk_norm: Optional[str] = None
    only_cross_attention: bool = False
    eps: float = 1e-5
    rescale_output_factor: float = 1.0
    residual_connection: bool = False
    out_dim: Optional[int] = None
    use_tpu_flash_attention: bool = True
    use_rope: bool = False
    attention_op: Optional[nn.Module] = None

    dtype: jnp.dtype = jnp.float32
    weight_dtype: jnp.dtype = jnp.float32
    matmul_precision: str = "default"

    def setup(self):
        """Initialize layers in Flax `setup()`."""
        self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads
        self.use_bias = self.bias
        self.is_cross_attention = self.cross_attention_dim is not None
        self.fused_projections = False
        out_dim = self.out_dim if self.out_dim is not None else self.query_dim
        self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0

        # Query and key normalization.
        if self.qk_norm is None:
            self.q_norm = Identity()
            self.k_norm = Identity()
        elif self.qk_norm == "rms_norm":
            self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",))
            self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",))
        elif self.qk_norm == "layer_norm":
            self.q_norm = nn.LayerNorm(epsilon=self.eps)
            self.k_norm = nn.LayerNorm(epsilon=self.eps)
        else:
            raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}")

        if out_dim is not None:
            self.heads_count = out_dim // self.dim_head

        # Validate parameters.
        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. "
                "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if self.cross_attention_norm is None:
            self.norm_cross = None
        elif self.cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(epsilon=self.eps)
        else:
            raise ValueError(
                f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'."
            )

        # Linear projections for queries, keys, values.
        self.to_q = DenseGeneral(
            features=(self.inner_dim,),
            use_bias=self.bias,
            name="to_q",
            matmul_precision=self.matmul_precision,
            weight_dtype=self.weight_dtype,
            dtype=self.dtype,
            kernel_axes=("embed", "kv"),
            axis=-1,
        )

        if not self.only_cross_attention:
            self.to_k = DenseGeneral(
                features=(self.inner_dim,),
                use_bias=self.bias,
                name="to_k",
                matmul_precision=self.matmul_precision,
                weight_dtype=self.weight_dtype,
                dtype=self.dtype,
                kernel_axes=("embed", "kv_head_dim"),
                axis=-1,
            )
            self.to_v = DenseGeneral(
                features=(self.inner_dim,),
                use_bias=self.bias,
                name="to_v",
                matmul_precision=self.matmul_precision,
                weight_dtype=self.weight_dtype,
                dtype=self.dtype,
                kernel_axes=("embed", "kv_head_dim"),
                axis=-1,
            )
        else:
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj")
            self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj")

        self.to_out = [
            DenseGeneral(
                features=(out_dim,),
                use_bias=self.out_bias,
                axis=-1,
                kernel_axes=("kv", "embed"),
                dtype=self.dtype,
                weight_dtype=self.weight_dtype,
                name="to_out.0",
                matmul_precision=self.matmul_precision,
            ),
            nn.Dropout(self.dropout),
        ]

        if self.attention_op is not None:
            self.attention = self.attention_op
        else:
            _tpu_available = any(device.platform == "tpu" for device in jax.devices())
            self.attention = AttentionOp() if _tpu_available else ExplicitAttention()
            if not _tpu_available:
                print("Warning: Running with explicit attention since tpu is not available.")

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        segment_ids: Optional[jnp.ndarray] = None,
        kv_attention_segment_ids: Optional[jnp.ndarray] = None,
        sharding_mesh: Optional[jax.sharding.Mesh] = None,
        skip_layer_mask: Optional[jnp.ndarray] = None,
        skip_layer_strategy: Optional[str] = None,
        temb: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        **cross_attention_kwargs,
    ) -> jnp.ndarray:
        # BUG FIX: the original filtered against an undefined global name
        # ``attn_parameters`` which raised NameError for any non-empty kwargs.
        # Only "scale" is recognised here, and it must be None.
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k == "scale"}
        assert cross_attention_kwargs.get("scale", None) is None, "Not supported"

        input_axis_names = ("activation_batch", "activation_length", "activation_embed")
        hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names)
        if encoder_hidden_states is not None:
            encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names)

        residual = hidden_states
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width))
            hidden_states = jnp.swapaxes(hidden_states, 1, 2)

        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        if skip_layer_mask is not None:
            skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1))

        query = self.to_q(hidden_states)
        query = self.q_norm(query)

        if encoder_hidden_states is not None:
            if self.norm_cross:
                encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
            key = self.to_k(encoder_hidden_states)
            key = self.k_norm(key)
        else:  # self-attention path (optionally with RoPE)
            encoder_hidden_states = hidden_states
            key = self.to_k(hidden_states)
            key = self.k_norm(key)
            if self.use_rope:
                key = apply_rotary_emb(key, freqs_cis)
                query = apply_rotary_emb(query, freqs_cis)

        value = self.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads

        # [B, T, H*D] -> [B, H, T, D]
        query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim))
        query = jnp.swapaxes(query, 1, 2)
        query = nn.with_logical_constraint(
            query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim")
        )
        query = checkpoint_name(query, "attention query")

        key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim))
        key = jnp.swapaxes(key, 1, 2)
        key = nn.with_logical_constraint(
            key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim")
        )
        key = checkpoint_name(key, "attention key")

        value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim))
        value = jnp.swapaxes(value, 1, 2)
        value = nn.with_logical_constraint(
            value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim")
        )
        value = checkpoint_name(value, "attention value")

        assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`"

        q_segment_ids = segment_ids
        if q_segment_ids is not None:
            q_segment_ids = q_segment_ids.astype(jnp.float32)

        # Cross-attention with masked context but unmasked queries.
        if kv_attention_segment_ids is not None and q_segment_ids is None:
            q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32)

        hidden_states_a = self.attention(
            query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype
        )

        hidden_states_a: jax.Array = nn.with_logical_constraint(
            hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv")
        )

        # [B, H, T, D] -> [B, T, H*D]
        hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim))
        if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip:
            hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask)
        else:
            hidden_states = hidden_states_a

        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states, deterministic=deterministic)  # Dropout

        if input_ndim == 4:
            hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width))
            if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual:
                skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1))

        if self.residual_connection:
            if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual:
                hidden_states = hidden_states + residual * skip_layer_mask
            else:
                hidden_states = hidden_states + residual

        if self.rescale_output_factor != 1.0:
            hidden_states = hidden_states / self.rescale_output_factor
        hidden_states = checkpoint_name(hidden_states, "attention_output")

        return hidden_states

    def prepare_attention_mask(
        self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3
    ) -> jnp.ndarray:
        """Pad an attention mask to ``target_length`` and broadcast it over heads."""
        head_size = self.heads_count
        if attention_mask is None:
            return attention_mask

        current_length = attention_mask.shape[-1]
        if current_length != target_length:
            remaining_length = target_length - current_length
            attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = jnp.repeat(attention_mask, head_size, axis=0)
        elif out_dim == 4:
            attention_mask = jnp.expand_dims(attention_mask, axis=1)
            attention_mask = jnp.repeat(attention_mask, head_size, axis=1)

        return attention_mask

    def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray:
        """Apply the configured cross-attention norm to the encoder states."""
        assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states."

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # GroupNorm normalizes over the channel axis, so transpose around it.
            encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2)
        else:
            raise ValueError("Unknown normalization type for cross-attention.")

        return encoder_hidden_states
class AttentionOp(nn.Module):
    """TPU flash-attention wrapper, optionally shard_map-ed over a device mesh."""

    @nn.compact
    def __call__(
        self,
        q: jax.Array,  # [batch_size, heads, q_tokens, hidden_dim]
        k: jax.Array,  # [batch_size, heads, kv_tokens, hidden_dim]
        v: jax.Array,  # [batch_size, heads, kv_tokens, hidden_dim]
        q_segment_ids: jax.Array,  # [batch_size, q_tokens]
        kv_segment_ids: jax.Array,  # [batch_size, kv_tokens]
        sharding_mesh: Optional[jax.sharding.Mesh] = None,
        dtype: jnp.dtype = jnp.float32,
        block_sizes: Optional[BlockSizes] = None,
    ):
        resolved_blocks = block_sizes if block_sizes is not None else self.default_block_sizes(q, k, dtype)
        sm_scale = 1 / math.sqrt(q.shape[-1])

        def run_flash(q, k, v, q_segment_ids, kv_segment_ids):
            # Flash attention expects float32 segment ids.
            seg = None
            if q_segment_ids is not None and kv_segment_ids is not None:
                seg = SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32))
            return jax_flash_attention(
                q,
                k,
                v,
                None,
                seg,
                sm_scale=sm_scale,
                block_sizes=resolved_blocks,
            )

        if sharding_mesh is None:
            return run_flash(q, k, v, q_segment_ids, kv_segment_ids)

        if q.ndim != 4:
            raise ValueError(f"Expected input with 4 dims, got {q.ndim}.")
        if q_segment_ids is not None and q_segment_ids.ndim != 2:
            raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.")
        # Derived from the logical axes
        # ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim");
        # see logical_axes_to_spec.py for the mapping.
        qkvo_spec = jax.sharding.PartitionSpec(
            ("data", "fsdp", "fsdp_transpose", "expert"),
            ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
            None,
            None,
        )
        # Derived from ("activation_kv_batch", "activation_length").
        seg_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence")
        sharded_flash = shard_map(
            run_flash,
            mesh=sharding_mesh,
            in_specs=(qkvo_spec, qkvo_spec, qkvo_spec, seg_spec, seg_spec),
            out_specs=qkvo_spec,
            check_rep=False,
        )
        return sharded_flash(q, k, v, q_segment_ids, kv_segment_ids)

    def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes:
        """
        Default grid block sizes for flash attention.

        TPU kernels run over grids; larger blocks load more data into the fast
        on-chip SRAM, but blocks that are too large cause cache misses while
        the SRAM refills from the slower HBM. The balance depends on the q/kv
        sequence lengths and the per-core SRAM size (e.g. ~1MB on TPU v5p),
        so bf16 inputs get a larger cap than f32.

        Args:
            q (jax.Array): Query tensor to be used
            k (jax.Array): Key tensor to be used

        Returns:
            BlockSizes: Grid block sizes
        """
        cap = 1024 if dtype == jnp.bfloat16 else 512
        q_len, kv_len = q.shape[-2], k.shape[-2]
        return BlockSizes(
            block_q=min(cap, q_len),
            block_k_major=min(cap, kv_len),
            block_k=min(cap, kv_len),
            block_b=min(1, q.shape[0]),
            block_q_major_dkv=min(cap, q_len),
            block_k_major_dkv=min(cap, kv_len),
            block_q_dkv=min(cap, q_len),
            block_k_dkv=min(cap, kv_len),
            # NOTE(review): fixed 512 here regardless of dtype — looks intentional, confirm.
            block_q_dq=min(cap, q_len),
            block_k_dq=min(512, kv_len),
            block_k_major_dq=min(cap, kv_len),
        )


class ExplicitAttention(nn.Module):
    """Naive O(n^2) softmax attention, used when no TPU is available."""

    def __call__(
        self,
        q: jax.Array,
        k: jax.Array,
        v: jax.Array,
        q_segment_ids: jax.Array,
        kv_segment_ids: jax.Array,
        sharding_mesh: Optional[jax.sharding.Mesh] = None,
        dtype: jnp.dtype = jnp.float32,
    ):
        assert sharding_mesh is None, "Explicit attention does not support sharding mesh."

        # Build a [B, 1, Tq, Tkv] mask from the segment ids, if provided.
        attn_mask = None
        if kv_segment_ids is not None:
            attn_mask = q_segment_ids[:, None, :, None] == kv_segment_ids[:, None, None, :]

        scale_factor = 1 / jnp.sqrt(q.shape[-1])
        attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype)
        if attn_mask is not None:
            if attn_mask.dtype == jnp.bool_:
                attn_bias = jnp.where(attn_mask, attn_bias, float("-inf"))
            else:
                attn_bias += attn_mask

        logits = q @ k.swapaxes(-2, -1) * scale_factor
        logits += attn_bias
        weights = jnn.softmax(logits, axis=-1)
        return weights @ v
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis, with optional learned scale."""

    epsilon: float
    dtype: jnp.dtype = jnp.float32
    elementwise_affine: bool = True
    weight_dtype: jnp.dtype = jnp.float32
    kernel_axes: Tuple[Optional[str], ...] = ()
    scale_init: Initializer = nn.initializers.ones

    @nn.compact
    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        """Normalize ``hidden_states`` by its RMS; variance is computed in float32.

        NOTE: if weight is in mixed precision, the operand should be in the
        same precision.

        Args:
            hidden_states (jax.Array): Input data

        Returns:
            jax.Array: Normed data
        """
        feature_dim = hidden_states.shape[-1]
        scale = (
            self.param(
                "scale",
                nn.with_logical_partitioning(self.scale_init, self.kernel_axes),
                (feature_dim,),
                self.weight_dtype,
            )
            if self.elementwise_affine
            else None
        )

        original_dtype = hidden_states.dtype
        # Mean of squares in float32 for numerical stability.
        mean_square = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True)
        normalized: jax.Array = hidden_states * jax.lax.rsqrt(mean_square + self.epsilon)

        if scale is not None:
            # Apply the learned scale in compute dtype, then restore the input dtype.
            return (normalized.astype(self.dtype) * scale.astype(self.dtype)).astype(original_dtype)
        return normalized.astype(original_dtype)


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    dim_out: Optional[int] = None
    mult: int = 4
    dropout: float = 0.0
    activation_fn: str = "gelu"
    final_dropout: bool = False
    bias: bool = True
    inner_dim: Optional[int] = None
    dtype: jnp.dtype = jnp.float32
    weight_dtype: jnp.dtype = jnp.float32
    matmul_precision: str = "default"

    @nn.compact
    def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array:
        in_dim = hidden_states.shape[-1]

        # Hidden width: explicit, or in_dim * mult rounded to a multiple of 256.
        if self.inner_dim is not None:
            hidden_dim = self.inner_dim
        else:
            hidden_dim = in_dim * self.mult
            if hidden_dim < 256:
                raise ValueError("inner_dim must be at least 256")
            hidden_dim = round(hidden_dim / 256) * 256  # round to nearest multiple of 256

        out_dim = in_dim if self.dim_out is None else self.dim_out

        shared_kwargs = {
            "name": "net.0",
            "bias": self.bias,
            "kernel_axes": ("embed", "mlp"),
            "matmul_precision": self.matmul_precision,
            "weight_dtype": self.weight_dtype,
            "dtype": self.dtype,
        }
        if self.activation_fn == "gelu":
            act_fn = GELU(in_dim, hidden_dim, **shared_kwargs)
        elif self.activation_fn == "gelu-approximate":
            act_fn = GELU(in_dim, hidden_dim, approximate="tanh", **shared_kwargs)
        elif self.activation_fn == "geglu":
            act_fn = GEGLU(in_dim, hidden_dim, **shared_kwargs)
        elif self.activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(in_dim, hidden_dim, **shared_kwargs)
        else:
            raise ValueError(f"activation function {self.activation_fn} not supported")

        # Only GEGLU consumes the lora-style `scale` argument.
        hidden_states = act_fn(hidden_states, scale) if isinstance(act_fn, GEGLU) else act_fn(hidden_states)

        hidden_states = checkpoint_name(hidden_states, "FFN - activation")
        hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic)

        hidden_states = DenseGeneral(
            out_dim,
            use_bias=self.bias,
            kernel_axes=("mlp", "embed"),
            matmul_precision=self.matmul_precision,
            weight_dtype=self.weight_dtype,
            dtype=self.dtype,
            name="net.2",
        )(hidden_states)
        hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection")

        if self.final_dropout:
            # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
            hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic)

        return hidden_states
def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array:
    """
    Integrates positional information into input tensors using RoPE.

    Consecutive pairs of features are treated as 2-D coordinates and rotated
    by the per-position angles encoded in ``freqs_cis``.

    Args:
        input_tensor (jax.Array): Input tensor (from QKV of attention mechanism)
        freqs_cis (Tuple[jax.Array, jax.Array]): The cosine and sine frequencies

    Returns:
        jax.Array: Tensor where positional information has been integrated into the original input tensor

    Raises:
        ValueError: if ``freqs_cis`` does not have exactly two elements.
    """
    if len(freqs_cis) != 2:
        raise ValueError("freqs_cis must be a tuple of 2 elements")

    cos_freqs, sin_freqs = freqs_cis

    # View the last axis as (pairs, 2) and rotate each pair: (a, b) -> (-b, a).
    paired = input_tensor.reshape(input_tensor.shape[:-1] + (-1, 2))
    first = paired[..., 0:1]
    second = paired[..., 1:2]
    rotated = jnp.concatenate((-second, first), axis=-1).reshape(input_tensor.shape)

    # Combine the identity and rotated components with the angle tables.
    return input_tensor * cos_freqs + rotated * sin_freqs
+ """ + + in_features: int + hidden_size: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, caption): + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + )(caption) + hidden_states = approximate_gelu(hidden_states) + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + )(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py new file mode 100644 index 000000000..2e8d86b97 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -0,0 +1,311 @@ +from typing import List, Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.adaln import AdaLayerNormSingle +from maxdiffusion.models.ltx_video.transformers.attention import BasicTransformerBlock +from maxdiffusion.models.ltx_video.transformers.caption_projection import CaptionProjection +from maxdiffusion.models.ltx_video.gradient_checkpoint import GradientCheckpointType +from maxdiffusion.models.ltx_video.repeatable_layer import RepeatableLayer + + +class Transformer3DModel(nn.Module): + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + 
double_self_attention: bool = False + upcast_attention: bool = False + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." 
+ self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + 
double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): + + #bookkeeping, for convenient changes later + latents_shape = (batch_size, num_tokens, features) + fractional_cords_shape = (batch_size, 3, num_tokens) + prompt_embeds_shape = (batch_size, text_tokens, features) + noise_cond_shape = (batch_size, 1) + latents_dtype = jnp.bfloat16 + fractional_coords_dtype = jnp.bfloat16 + prompt_embeds_dtype = jnp.bfloat16 + noise_cond_dtype = jnp.bfloat16 + + #initialize to random + key, split_key = jax.random.split(key) + prompt_embeds = jax.random.normal(split_key, shape=prompt_embeds_shape, dtype=latents_dtype) + key, split_key = jax.random.split(key) + fractional_coords = jax.random.normal(split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) + key, split_key = jax.random.split(key) + latents = jax.random.normal(split_key, shape=latents_shape, dtype=prompt_embeds_dtype) + key, split_key = jax.random.split(key) + noise_cond = jax.random.normal(split_key, 
shape=noise_cond_shape, dtype=noise_cond_dtype) + + + key, split_key = jax.random.split(key) + if eval_only: + return jax.eval_shape( + self.init, + rngs = {"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states = prompt_embeds, + timestep=noise_cond, + )["params"] + else: + return self.init( + rngs = {"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states = prompt_embeds, + timestep=noise_cond, + )["params"] + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + + if self.num_layers > 0: + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + ) + # Output processing + + scale_shift_values = ( + self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] + ) + scale_shift_values = nn.with_logical_constraint( + scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + 
shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states + + +def log_base(x: jax.Array, base: jax.Array) -> jax.Array: + """ + Computes log of x with defined base. + + Args: + x (jax.Array): log value + base (jax.Array): base of the log + + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) + + + + + +class FreqsCisPrecomputer(nn.Module): + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + dtype = jnp.float32 # We need full precision in the freqs_cis computation. 
+ dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json new file mode 100644 index 000000000..770093859 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -0,0 +1,24 @@ +{ + "activation_fn": "gelu-approximate", + "attention_bias": true, + "attention_head_dim": 128, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 4096, + "double_self_attention": false, + "dropout": 0.0, + "norm_elementwise_affine": false, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 48, + "only_cross_attention": false, + "out_channels": 128, + "upcast_attention": false, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000 +} From 
0eb3303c7e230fbae0eb53193779fa01ab2a5c08 Mon Sep 17 00:00:00 2001 From: Serena Date: Thu, 12 Jun 2025 20:14:18 +0000 Subject: [PATCH 2/4] conversion done, fixing sharding issue --- .../checkpointing/checkpointing_utils.py | 9 +- src/maxdiffusion/configs/ltx_video.yml | 1 + src/maxdiffusion/generate_ltx_video.py | 185 ++- src/maxdiffusion/max_utils.py | 6 +- src/maxdiffusion/models/ltx_video/main.py | 44 +- .../ltx_video/transformers/attention.py | 26 +- .../ltx_video/transformers/transformer3d.py | 55 +- .../transformers_pytorch/__init__.py | 0 .../transformers_pytorch/attention.py | 1265 +++++++++++++++++ .../transformers_pytorch/embeddings.py | 129 ++ .../symmetric_patchifier.py | 84 ++ .../transformers_pytorch/transformer_pt.py | 507 +++++++ .../utils/convert_torch_weights_to_jax.py | 256 ++++ .../utils/diffusers_config_mapping.py | 174 +++ .../ltx_video/utils/skip_layer_strategy.py | 8 + .../models/ltx_video/utils/torch_compat.py | 519 +++++++ .../ltx_video/xora_v1.2-13B-balanced-128.json | 4 +- .../tests/ltx_video_transformer_test.py | 306 ++++ test_lightning.png | Bin 1343016 -> 0 bytes 19 files changed, 3487 insertions(+), 91 deletions(-) create mode 100644 src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer_pt.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/torch_compat.py create mode 100644 
src/maxdiffusion/tests/ltx_video_transformer_test.py delete mode 100644 test_lightning.png diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index dd78eaa6c..7366cf86d 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -213,8 +213,13 @@ def load_state_if_possible( max_logging.log(f"restoring from this run's directory latest step {latest_step}") try: if not enable_single_replica_ckpt_restoring: - item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} - return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) + # item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} + # return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) #currently changed to this + if checkpoint_item == " ": + return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state)) + else: + item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} + return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) #currently changed to this def map_to_pspec(data): pspec = data.sharding.spec diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 798cd9f83..b60529884 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -48,3 +48,4 @@ per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 jit_initializers: True +enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 7cd2b26c5..3491b4223 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,24 
+1,76 @@ +from json import encoder from absl import app from typing import Sequence import jax +from flax import linen as nn import json +from flax.linen import partitioning as nn_partitioning from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import os import functools import jax.numpy as jnp -from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging +from maxdiffusion import pyconfig from maxdiffusion.max_utils import ( create_device_mesh, setup_initial_state, + get_memory_allocations, ) -from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P +from jax.sharding import Mesh, PartitionSpec as P +import orbax.checkpoint as ocp -def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): +def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids): print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) print("latents.shape: ", latents.shape, latents.dtype) print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) + print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) + + +def loop_body( + step, + args, + transformer, + fractional_cords, + prompt_embeds, + segment_ids, + encoder_attention_segment_ids +): + latents, state, noise_cond = args + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids + ) + import pdb; pdb.set_trace() + return noise_pred, state, noise_cond #need to make changes here? 
latents need to be changed based on noise_pred, but needs scheduler, return noise_pred for now + + + +def run_inference( + states, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, segment_ids, encoder_attention_segment_ids +): + transformer_state = states["transformer"] + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + fractional_cords=fractional_cords, + prompt_embeds=prompt_embeds, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids + ) + ## TODO: add vae decode step + ## TODO: add loop + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) + return latents + def run(config): key = jax.random.PRNGKey(0) @@ -26,39 +78,134 @@ def run(config): devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 base_dir = os.path.dirname(__file__) ##load in model config config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") with open(config_path, "r") as f: model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "in_channels", "ckpt_path"] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] - transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") - transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only = False) + + transformer = Transformer3DModel(**model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh) + transformer_param_shapes = transformer.init_weights(in_channels, 
model_config['caption_channels'], eval_only = True) #use this to test! - key, split_key = jax.random.split(key) weights_init_fn = functools.partial( transformer.init_weights, - split_key, - batch_size, - text_tokens, - num_tokens, - features, - eval_only = False + in_channels, + model_config['caption_channels'], + eval_only = True ) - transformer_state, transformer_state_shardings = setup_initial_state( + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( model=transformer, tx=None, config=config, mesh=mesh, weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", model_params=None, training=False, ) + + + + + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + states["transformer"] = transformer_state + + #create dummy inputs: + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "latents": (batch_size, num_tokens, in_channels), + "fractional_coords": (batch_size, 3, num_tokens), + "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), + "timestep": (batch_size, 256), #TODO: add in the segment id stuff + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(example_inputs["latents"], data_sharding) + prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + fractional_coords = jax.device_put(example_inputs["fractional_coords"], 
data_sharding) + noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + + validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + config=config, + mesh=mesh, + latents=latents, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds, + timestep = noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids + ), + in_shardings=(state_shardings,), + out_shardings=None, + ) + noise_pred = p_run_inference(states).block_until_ready() + print(noise_pred) #(4, 256, 128) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + def main(argv: Sequence[str]) -> None: @@ -71,3 +218,13 @@ def main(argv: Sequence[str]) -> None: + +###setup_initial_state, can optionally load from checkpoint + + + + + + + +#end to end steps from ltx repo: pipeline_ltx_video.py diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f97..c04aab349 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -402,7 +402,11 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + ###!Edited + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( diff --git a/src/maxdiffusion/models/ltx_video/main.py b/src/maxdiffusion/models/ltx_video/main.py index f21cc5e46..9ed184cd5 100644 --- a/src/maxdiffusion/models/ltx_video/main.py +++ b/src/maxdiffusion/models/ltx_video/main.py @@ -1,40 +1,22 @@ + +import 
argparse +import json +from typing import Any, Dict, Optional import os import jax import jax.numpy as jnp -import json +import jax.lib.xla_extension +import flax +from flax.training import train_state +import torch +import optax +import orbax.checkpoint as ocp +from safetensors.torch import load_file +from maxdiffusion.models.ltx_video.transformers_pytorch.transformer_pt import Transformer3DModel_PT -from models.transformers.transformer3d import Transformer3DModel - -# Load JSON config base_dir = os.path.dirname(__file__) config_path = os.path.join(base_dir, "xora_v1.2-13B-balanced-128.json") with open(config_path, "r") as f: model_config = json.load(f) - -key = jax.random.PRNGKey(0) -model = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") - -batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 -prompt_embeds = jax.random.normal(key, shape=(batch_size, text_tokens, features), dtype=jnp.bfloat16) -fractional_coords = jax.random.normal(key, shape=(batch_size, 3, num_tokens), dtype=jnp.bfloat16) -latents = jax.random.normal(key, shape=(batch_size, num_tokens, features), dtype=jnp.bfloat16) -noise_cond = jax.random.normal(key, shape=(batch_size, 1), dtype=jnp.bfloat16) - -model_params = model.init( - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - rngs={"params": key} -) - -output = model.apply( - model_params, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, -) - -print("done!") +transformer = Transformer3DModel_PT.from_config(model_config) \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index e4c3351ee..543981390 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ 
b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -1,8 +1,8 @@ from functools import partial +import functools import math from typing import Any, Dict, Optional, Tuple from enum import Enum, auto - import jax import jax.nn as jnn import jax.numpy as jnp @@ -604,7 +604,8 @@ def __call__( block_sizes = self.default_block_sizes(q, k, dtype) scale_factor = 1 / math.sqrt(q.shape[-1]) - + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): s = ( # flash attention expects segment ids to be float32 @@ -630,14 +631,27 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. + # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) + # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) qkvo_sharding_spec = jax.sharding.PartitionSpec( - ("data", "fsdp", "fsdp_transpose", "expert"), - ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + None, + None, None, None, ) - # Based on: ("activation_kv_batch", "activation_length") - qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + #Based on: ("activation_kv_batch", "activation_length") + # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None) wrapped_flash_attention = shard_map( partial_flash_attention, mesh=sharding_mesh, diff --git 
a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 2e8d86b97..5c087f42a 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -137,47 +137,30 @@ def scale_shift_table_init(key): weight_dtype=self.weight_dtype, matmul_precision=self.matmul_precision, ) - def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): - - #bookkeeping, for convenient changes later - latents_shape = (batch_size, num_tokens, features) - fractional_cords_shape = (batch_size, 3, num_tokens) - prompt_embeds_shape = (batch_size, text_tokens, features) - noise_cond_shape = (batch_size, 1) - latents_dtype = jnp.bfloat16 - fractional_coords_dtype = jnp.bfloat16 - prompt_embeds_dtype = jnp.bfloat16 - noise_cond_dtype = jnp.bfloat16 - - #initialize to random - key, split_key = jax.random.split(key) - prompt_embeds = jax.random.normal(split_key, shape=prompt_embeds_shape, dtype=latents_dtype) - key, split_key = jax.random.split(key) - fractional_coords = jax.random.normal(split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) - key, split_key = jax.random.split(key) - latents = jax.random.normal(split_key, shape=latents_shape, dtype=prompt_embeds_dtype) - key, split_key = jax.random.split(key) - noise_cond = jax.random.normal(split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) - - - key, split_key = jax.random.split(key) + def init_weights(self, in_channels, caption_channels, eval_only=True): + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for 
name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + if eval_only: return jax.eval_shape( self.init, - rngs = {"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states = prompt_embeds, - timestep=noise_cond, + jax.random.PRNGKey(42), ##need to change? + **example_inputs, )["params"] else: - return self.init( - rngs = {"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states = prompt_embeds, - timestep=noise_cond, - )["params"] + return self.init(jax.random.PRNGKey(42), **example_inputs)['params'] def __call__( self, diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py new file mode 100644 index 000000000..55c5bf371 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py @@ -0,0 +1,1265 @@ +import inspect +from importlib import import_module +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention import _chunked_feed_forward +from diffusers.models.attention_processor import ( + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + SpatialNorm, +) +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.models.normalization import RMSNorm +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange +from torch import nn + +from 
try:
    from torch_xla.experimental.custom_kernel import flash_attention
except ImportError:
    # workaround for automatic tests. Currently this function is manually patched
    # to the torch_xla lib on setup of container
    pass

# code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py

logger = logging.get_logger(__name__)


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block: pre-norm self-attention, optional cross-attention,
    and a feed-forward layer, with optional adaptive (timestep-conditioned) norms.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        qk_norm (`str`, *optional*, defaults to None):
            Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
        adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`):
            The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none".
        standardization_norm (`str`, *optional*, defaults to `"layer_norm"`):
            The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,  # pylint: disable=unused-argument
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        adaptive_norm: str = "single_scale_shift",  # 'single_scale_shift', 'single_scale' or 'none'
        standardization_norm: str = "layer_norm",  # 'layer_norm' or 'rms_norm'
        norm_eps: float = 1e-5,
        qk_norm: Optional[str] = None,
        final_dropout: bool = False,
        attention_type: str = "default",  # pylint: disable=unused-argument
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
        use_tpu_flash_attention: bool = False,
        use_rope: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_tpu_flash_attention = use_tpu_flash_attention
        self.adaptive_norm = adaptive_norm

        assert standardization_norm in ["layer_norm", "rms_norm"]
        assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]

        make_norm_layer = (
            nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm
        )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = make_norm_layer(
            dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
        )

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
            use_tpu_flash_attention=use_tpu_flash_attention,
            qk_norm=qk_norm,
            use_rope=use_rope,
        )

        # 2. Cross-Attn (self-attn if encoder_hidden_states is None and
        # double_self_attention is requested)
        if cross_attention_dim is not None or double_self_attention:
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=(
                    cross_attention_dim if not double_self_attention else None
                ),
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                out_bias=attention_out_bias,
                use_tpu_flash_attention=use_tpu_flash_attention,
                qk_norm=qk_norm,
                use_rope=use_rope,
            )  # is self-attn if encoder_hidden_states is none

            if adaptive_norm == "none":
                # NOTE(review): positional call — for nn.LayerNorm this binds
                # (normalized_shape, eps, elementwise_affine); confirm RMSNorm's
                # signature matches the same order.
                self.attn2_norm = make_norm_layer(
                    dim, norm_eps, norm_elementwise_affine
                )
        else:
            self.attn2 = None
            self.attn2_norm = None

        self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine)

        # 3. Feed-forward
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # 5. Scale-shift for PixArt-Alpha.
        if adaptive_norm != "none":
            num_ada_params = 4 if adaptive_norm == "single_scale" else 6
            self.scale_shift_table = nn.Parameter(
                torch.randn(num_ada_params, dim) / dim**0.5
            )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_use_tpu_flash_attention(self):
        r"""
        Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
        attention kernel.
        """
        self.use_tpu_flash_attention = True
        self.attn1.set_use_tpu_flash_attention()
        # FIX: attn2 is None when the block has no cross-attention
        # (cross_attention_dim is None and double_self_attention is False);
        # calling it unconditionally raised AttributeError.
        if self.attn2 is not None:
            self.attn2.set_use_tpu_flash_attention()

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        skip_layer_mask: Optional[torch.Tensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
    ) -> torch.FloatTensor:
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored."
                )

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        # Kept for the SkipLayerStrategy.TransformerBlock blend at the end.
        original_hidden_states = hidden_states

        norm_hidden_states = self.norm1(hidden_states)

        # Apply ada_norm_single
        if self.adaptive_norm in ["single_scale_shift", "single_scale"]:
            assert timestep.ndim == 3  # [batch, 1 or num_tokens, embedding_dim]
            num_ada_params = self.scale_shift_table.shape[0]
            ada_values = self.scale_shift_table[None, None] + timestep.reshape(
                batch_size, timestep.shape[1], num_ada_params, -1
            )
            if self.adaptive_norm == "single_scale_shift":
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                    ada_values.unbind(dim=2)
                )
                norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            else:
                scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
                norm_hidden_states = norm_hidden_states * (1 + scale_msa)
        elif self.adaptive_norm == "none":
            scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None
        else:
            raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")

        norm_hidden_states = norm_hidden_states.squeeze(
            1
        )  # TODO: Check if this is needed

        # 1. Prepare GLIGEN inputs
        cross_attention_kwargs = (
            cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        )

        attn_output = self.attn1(
            norm_hidden_states,
            freqs_cis=freqs_cis,
            encoder_hidden_states=(
                encoder_hidden_states if self.only_cross_attention else None
            ),
            attention_mask=attention_mask,
            skip_layer_mask=skip_layer_mask,
            skip_layer_strategy=skip_layer_strategy,
            **cross_attention_kwargs,
        )
        if gate_msa is not None:
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.adaptive_norm == "none":
                attn_input = self.attn2_norm(hidden_states)
            else:
                attn_input = hidden_states
            attn_output = self.attn2(
                attn_input,
                freqs_cis=freqs_cis,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm2(hidden_states)
        if self.adaptive_norm == "single_scale_shift":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        elif self.adaptive_norm == "single_scale":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp)
        elif self.adaptive_norm == "none":
            pass
        else:
            raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(
                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
            )
        else:
            ff_output = self.ff(norm_hidden_states)
        if gate_mlp is not None:
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        if (
            skip_layer_mask is not None
            and skip_layer_strategy == SkipLayerStrategy.TransformerBlock
        ):
            skip_layer_mask = skip_layer_mask.view(-1, 1, 1)
            hidden_states = hidden_states * skip_layer_mask + original_hidden_states * (
                1.0 - skip_layer_mask
            )

        return hidden_states
@maybe_allow_in_graph
class Attention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`):
            The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8):
            The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64):
            The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the attention computation to `float32`.
        upcast_softmax (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the softmax computation to `float32`.
        cross_attention_norm (`str`, *optional*, defaults to `None`):
            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the group norm in the cross attention.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        norm_num_groups (`int`, *optional*, defaults to `None`):
            The number of groups to use for the group norm in the attention.
        spatial_norm_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the spatial normalization.
        out_bias (`bool`, *optional*, defaults to `True`):
            Set to `True` to use a bias in the output linear layer.
        scale_qk (`bool`, *optional*, defaults to `True`):
            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
        qk_norm (`str`, *optional*, defaults to None):
            Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
            `added_kv_proj_dim` is not `None`.
        eps (`float`, *optional*, defaults to 1e-5):
            An additional value added to the denominator in group normalization that is used for numerical stability.
        rescale_output_factor (`float`, *optional*, defaults to 1.0):
            A factor to rescale the output by dividing it with this value.
        residual_connection (`bool`, *optional*, defaults to `False`):
            Set to `True` to add the residual connection to the output.
        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
            Set to `True` if the attention block is loaded from a deprecated state dict.
        processor (`AttnProcessor`, *optional*, defaults to `None`):
            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
            `AttnProcessor` otherwise.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        qk_norm: Optional[str] = None,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor"] = None,
        out_dim: int = None,
        use_tpu_flash_attention: bool = False,
        use_rope: bool = False,
    ):
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = (
            cross_attention_dim if cross_attention_dim is not None else query_dim
        )
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.use_tpu_flash_attention = use_tpu_flash_attention
        self.use_rope = use_rope

        # we make use of this private variable to know whether this class is loaded
        # with an deprecated state dict so that we can convert it on the fly
        self._from_deprecated_attn_block = _from_deprecated_attn_block

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        # q/k norms operate on the FULL projected width (dim_head * heads),
        # i.e. they are applied before the tensor is split into heads.
        if qk_norm is None:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        elif qk_norm == "rms_norm":
            self.q_norm = RMSNorm(dim_head * heads, eps=1e-5)
            self.k_norm = RMSNorm(dim_head * heads, eps=1e-5)
        elif qk_norm == "layer_norm":
            self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
            self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
        else:
            raise ValueError(f"Unsupported qk_norm method: {qk_norm}")

        self.heads = out_dim // dim_head if out_dim is not None else heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(
                num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
            )
        else:
            self.group_norm = None

        if spatial_norm_dim is not None:
            self.spatial_norm = SpatialNorm(
                f_channels=query_dim, zq_channels=spatial_norm_dim
            )
        else:
            self.spatial_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # The given `encoder_hidden_states` are initially of shape
                # (batch_size, seq_len, added_kv_proj_dim) before being projected
                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
                # before the projection, so we need to use `added_kv_proj_dim` as
                # the number of channels for the group norm.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = self.cross_attention_dim

            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels,
                num_groups=cross_attention_norm_num_groups,
                eps=1e-5,
                affine=True,
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

        linear_cls = nn.Linear

        self.linear_cls = linear_cls
        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        # set attention processor
        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
        if processor is None:
            processor = AttnProcessor2_0()
        self.set_processor(processor)

    def set_use_tpu_flash_attention(self):
        r"""
        Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel.
        """
        self.use_tpu_flash_attention = True

    def set_processor(self, processor: "AttnProcessor") -> None:
        r"""
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(
                f"You are removing possibly trained weights of {self.processor} with {processor}"
            )
            self._modules.pop("processor")

        self.processor = processor

    def get_processor(
        self, return_deprecated_lora: bool = False
    ) -> "AttentionProcessor":  # noqa: F821
        r"""
        Get the attention processor in use.

        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to return the deprecated LoRA attention processor.

        Returns:
            "AttentionProcessor": The attention processor in use.
        """
        if not return_deprecated_lora:
            return self.processor

        # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
        # serialization format for LoRA Attention Processors. It should be deleted once the integration
        # with PEFT is completed.
        # NOTE(review): the LoRA* processor classes and `import_module` referenced
        # below must be imported at module level — confirm against the file header.
        is_lora_activated = {
            name: module.lora_layer is not None
            for name, module in self.named_modules()
            if hasattr(module, "lora_layer")
        }

        # 1. if no layer has a LoRA activated we can return the processor as usual
        if not any(is_lora_activated.values()):
            return self.processor

        # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
        is_lora_activated.pop("add_k_proj", None)
        is_lora_activated.pop("add_v_proj", None)
        # 2. else it is not posssible that only some layers have LoRA activated
        if not all(is_lora_activated.values()):
            raise ValueError(
                f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
            )

        # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
        non_lora_processor_cls_name = self.processor.__class__.__name__
        lora_processor_cls = getattr(
            import_module(__name__), "LoRA" + non_lora_processor_cls_name
        )

        hidden_size = self.inner_dim

        # now create a LoRA attention processor from the LoRA layers
        if lora_processor_cls in [
            LoRAAttnProcessor,
            LoRAAttnProcessor2_0,
            LoRAXFormersAttnProcessor,
        ]:
            kwargs = {
                "cross_attention_dim": self.cross_attention_dim,
                "rank": self.to_q.lora_layer.rank,
                "network_alpha": self.to_q.lora_layer.network_alpha,
                "q_rank": self.to_q.lora_layer.rank,
                "q_hidden_size": self.to_q.lora_layer.out_features,
                "k_rank": self.to_k.lora_layer.rank,
                "k_hidden_size": self.to_k.lora_layer.out_features,
                "v_rank": self.to_v.lora_layer.rank,
                "v_hidden_size": self.to_v.lora_layer.out_features,
                "out_rank": self.to_out[0].lora_layer.rank,
                "out_hidden_size": self.to_out[0].lora_layer.out_features,
            }

            if hasattr(self.processor, "attention_op"):
                kwargs["attention_op"] = self.processor.attention_op

            lora_processor = lora_processor_cls(hidden_size, **kwargs)
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(
                self.to_out[0].lora_layer.state_dict()
            )
        elif lora_processor_cls == LoRAAttnAddedKVProcessor:
            lora_processor = lora_processor_cls(
                hidden_size,
                cross_attention_dim=self.add_k_proj.weight.shape[0],
                rank=self.to_q.lora_layer.rank,
                network_alpha=self.to_q.lora_layer.network_alpha,
            )
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(
                self.to_out[0].lora_layer.state_dict()
            )

            # only save if used
            if self.add_k_proj.lora_layer is not None:
                lora_processor.add_k_proj_lora.load_state_dict(
                    self.add_k_proj.lora_layer.state_dict()
                )
                lora_processor.add_v_proj_lora.load_state_dict(
                    self.add_v_proj.lora_layer.state_dict()
                )
            else:
                lora_processor.add_k_proj_lora = None
                lora_processor.add_v_proj_lora = None
        else:
            raise ValueError(f"{lora_processor_cls} does not exist.")

        return lora_processor

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        skip_layer_mask: Optional[torch.Tensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            skip_layer_mask (`torch.Tensor`, *optional*):
                The skip layer mask to use. If `None`, no mask is applied.
            skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`):
                Controls which layers to skip for spatiotemporal guidance.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty

        attn_parameters = set(
            inspect.signature(self.processor.__call__).parameters.keys()
        )
        unused_kwargs = [
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by"
                f" {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {
            k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
        }

        return self.processor(
            self,
            hidden_states,
            freqs_cis=freqs_cis,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            skip_layer_mask=skip_layer_mask,
            skip_layer_strategy=skip_layer_strategy,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
        is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size // head_size, seq_len, dim * head_size
        )
        return tensor

    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
        the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
                reshaped to `[batch_size * heads, seq_len, dim // heads]`.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """

        head_size = self.heads
        if tensor.ndim == 3:
            batch_size, seq_len, dim = tensor.shape
            extra_dim = 1
        else:
            batch_size, extra_dim, seq_len, dim = tensor.shape
        tensor = tensor.reshape(
            batch_size, seq_len * extra_dim, head_size, dim // head_size
        )
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(
                batch_size * head_size, seq_len * extra_dim, dim // head_size
            )

        return tensor

    def get_attention_scores(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        attention_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        r"""
        Compute the attention scores.

        Args:
            query (`torch.Tensor`): The query tensor.
            key (`torch.Tensor`): The key tensor.
            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.

        Returns:
            `torch.Tensor`: The attention probabilities/scores.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            baddbmm_input = torch.empty(
                query.shape[0],
                query.shape[1],
                key.shape[1],
                dtype=query.dtype,
                device=query.device,
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(
        self,
        attention_mask: torch.Tensor,
        target_length: int,
        batch_size: int,
        out_dim: int = 3,
    ) -> torch.Tensor:
        r"""
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`):
                The attention mask to prepare.
            target_length (`int`):
                The target length of the attention mask. This is the length of the attention mask after padding.
            batch_size (`int`):
                The batch size, which is used to repeat the attention mask.
            out_dim (`int`, *optional*, defaults to `3`):
                The output dimension of the attention mask. Can be either `3` or `4`.

        Returns:
            `torch.Tensor`: The prepared attention mask.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (
                    attention_mask.shape[0],
                    attention_mask.shape[1],
                    target_length,
                )
                padding = torch.zeros(
                    padding_shape,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # FIX: pad by the REMAINING length so the mask ends up exactly
                # `target_length` long. The previous code padded by
                # `target_length`, producing a mask of length
                # `current_length + target_length` (this was flagged by the
                # upstream TODO and the disabled
                # tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding).
                remaining_length: int = target_length - current_length
                attention_mask = F.pad(
                    attention_mask, (0, remaining_length), value=0.0
                )

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(
        self, encoder_hidden_states: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
        `Attention` class.

        Args:
            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.

        Returns:
            `torch.Tensor`: The normalized encoder hidden states.
        """
        assert (
            self.norm_cross is not None
        ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states

    @staticmethod
    def apply_rotary_emb(
        input_tensor: torch.Tensor,
        freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Rotate interleaved pairs (x0, x1) -> (-x1, x0) and combine with the
        # precomputed (cos, sin) tables.
        cos_freqs = freqs_cis[0]
        sin_freqs = freqs_cis[1]

        t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
        t1, t2 = t_dup.unbind(dim=-1)
        t_dup = torch.stack((-t2, t1), dim=-1)
        input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")

        out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs

        return out
class AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    Also supports the TPU flash-attention kernel and skip-layer (spatiotemporal guidance) blending.
    """

    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        skip_layer_mask: Optional[torch.FloatTensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if skip_layer_mask is not None:
            skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1)

        if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        # q/k norms are applied on the full projected width, BEFORE splitting
        # into heads (they were constructed over dim_head * heads).
        query = attn.to_q(hidden_states)
        query = attn.q_norm(query)

        if encoder_hidden_states is not None:
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )
            key = attn.to_k(encoder_hidden_states)
            key = attn.k_norm(key)
        else:  # if no context provided do self-attention
            encoder_hidden_states = hidden_states
            key = attn.to_k(hidden_states)
            key = attn.k_norm(key)
            if attn.use_rope:
                key = attn.apply_rotary_emb(key, freqs_cis)
                query = attn.apply_rotary_emb(query, freqs_cis)

        value = attn.to_v(encoder_hidden_states)
        # kept un-split for the SkipLayerStrategy.AttentionValues blend below
        value_for_stg = value

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)

        if attn.use_tpu_flash_attention:  # use tpu attention offload 'flash attention'
            q_segment_indexes = None
            if (
                attention_mask is not None
            ):  # if mask is required need to tune both segmenIds fields
                # attention_mask = torch.squeeze(attention_mask).to(torch.float32)
                attention_mask = attention_mask.to(torch.float32)
                q_segment_indexes = torch.ones(
                    batch_size, query.shape[2], device=query.device, dtype=torch.float32
                )
                assert (
                    attention_mask.shape[1] == key.shape[2]
                ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]"

            assert (
                query.shape[2] % 128 == 0
            ), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]"
            assert (
                key.shape[2] % 128 == 0
            ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]"

            # run the TPU kernel implemented in jax with pallas
            hidden_states_a = flash_attention(
                q=query,
                k=key,
                v=value,
                q_segment_ids=q_segment_indexes,
                kv_segment_ids=attention_mask,
                sm_scale=attn.scale,
            )
        else:
            hidden_states_a = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attention_mask,
                dropout_p=0.0,
                is_causal=False,
            )

        hidden_states_a = hidden_states_a.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states_a = hidden_states_a.to(query.dtype)

        if (
            skip_layer_mask is not None
            and skip_layer_strategy == SkipLayerStrategy.AttentionSkip
        ):
            hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (
                1.0 - skip_layer_mask
            )
        elif (
            skip_layer_mask is not None
            and skip_layer_strategy == SkipLayerStrategy.AttentionValues
        ):
            hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (
                1.0 - skip_layer_mask
            )
        else:
            hidden_states = hidden_states_a

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
            if (
                skip_layer_mask is not None
                and skip_layer_strategy == SkipLayerStrategy.Residual
            ):
                skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1)

        if attn.residual_connection:
            if (
                skip_layer_mask is not None
                and skip_layer_strategy == SkipLayerStrategy.Residual
            ):
                hidden_states = hidden_states + residual * skip_layer_mask
            else:
                hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor:
    r"""
    Default processor for performing attention-related computations (baddbmm + softmax fallback).
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # FIX: apply q/k norms BEFORE splitting into heads. The norms are
        # constructed over the full projection width (dim_head * heads, see
        # Attention.__init__); applying them after head_to_batch_dim — as the
        # previous code did — fed them per-head tensors of width dim_head,
        # a shape mismatch for layer_norm/rms_norm and inconsistent with
        # AttnProcessor2_0.
        query = attn.q_norm(query)
        key = attn.k_norm(key)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + else: + raise ValueError(f"Unsupported activation function: {activation_fn}") + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. 
have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py new file mode 100644 index 000000000..14fc56d62 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py @@ -0,0 +1,129 @@ +# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py +import math + +import numpy as np +import torch +from einops import rearrange +from torch import nn + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w) + grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w) + grid = grid.reshape([3, 1, w, h, f]) + pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) + pos_embed = pos_embed.transpose(1, 0, 2, 3) + return rearrange(pos_embed, "h w f c -> (f h w) c") + + +def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 3 != 0: + raise ValueError("embed_dim must be divisible by 3") + + # use half of dimensions to encode grid_h + emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3) + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3) + + emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) 
+ """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos_shape = pos.shape + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + out = out.reshape([*pos_shape, -1])[0] + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D) + return emb + + +class SinusoidalPositionalEmbedding(nn.Module): + """Apply positional information to a sequence of embeddings. + + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + + Args: + embed_dim: (int): Dimension of the positional embedding. + max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim) + ) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] + return x \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py new file mode 100644 index 000000000..28ad834f3 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from diffusers.configuration_utils import ConfigMixin +from einops import rearrange +from torch import Tensor + + +class 
Patchifier(ConfigMixin, ABC): + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) + + @abstractmethod + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + raise NotImplementedError("Patchify method not implemented") + + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass + + @property + def patch_size(self): + return self._patch_size + + def get_latent_coords( + self, latent_num_frames, latent_height, latent_width, batch_size, device + ): + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. + The tensor is repeated for each batch element. + """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange( + latent_coords, "b c f h w -> b c (f h w)", b=batch_size + ) + return latent_coords + + +class SymmetricPatchifier(Patchifier): + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents, latent_coords + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + 
latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer_pt.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer_pt.py new file mode 100644 index 000000000..75b7b5100 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer_pt.py @@ -0,0 +1,507 @@ +# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union +import os +import json +import glob +from pathlib import Path + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import PixArtAlphaTextProjection +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle +from diffusers.utils import BaseOutput, is_torch_version +from diffusers.utils import logging +from torch import nn +from safetensors import safe_open + + +from maxdiffusion.models.ltx_video.transformers_pytorch.attention import BasicTransformerBlock +from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +from maxdiffusion.models.ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + TRANSFORMER_KEYS_RENAME_DICT, +) + + +logger = logging.get_logger(__name__) + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. 
+ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer3DModel_PT(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None, + positional_embedding_type: str = "rope", + positional_embedding_theta: Optional[float] = None, + positional_embedding_max_pos: Optional[List[int]] = None, + timestep_scale_multiplier: Optional[float] = None, + causal_temporal_positioning: bool = False, # For backward compatibility, will be deprecated + ): + super().__init__() + self.use_tpu_flash_attention = ( + use_tpu_flash_attention # FIXME: push config down to the attention modules + ) + 
self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True) + self.positional_embedding_type = positional_embedding_type + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.use_rope = self.positional_embedding_type == "rope" + self.timestep_scale_multiplier = timestep_scale_multiplier + + if self.positional_embedding_type == "absolute": + raise ValueError("Absolute positional embedding is no longer supported") + elif self.positional_embedding_type == "rope": + if positional_embedding_theta is None: + raise ValueError( + "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined" + ) + if positional_embedding_max_pos is None: + raise ValueError( + "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined" + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + adaptive_norm=adaptive_norm, + standardization_norm=standardization_norm, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=self.use_rope, + ) + for d in range(num_layers) + ] + ) + + # 4. 
Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter( + torch.randn(2, inner_dim) / inner_dim**0.5 + ) + self.proj_out = nn.Linear(inner_dim, self.out_channels) + + self.adaln_single = AdaLayerNormSingle( + inner_dim, use_additional_conditions=False + ) + if adaptive_norm == "single_scale": + self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=inner_dim + ) + + self.gradient_checkpointing = False + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. + """ + logger.info("ENABLE TPU FLASH ATTENTION -> TRUE") + self.use_tpu_flash_attention = True + # push config down to the attention modules + for block in self.transformer_blocks: + block.set_use_tpu_flash_attention() + + def create_skip_layer_mask( + self, + batch_size: int, + num_conds: int, + ptb_index: int, + skip_block_list: Optional[List[int]] = None, + ): + if skip_block_list is None or len(skip_block_list) == 0: + return None + num_layers = len(self.transformer_blocks) + mask = torch.ones( + (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype + ) + for block_idx in skip_block_list: + mask[block_idx, ptb_index::num_conds] = 0 + return mask + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [ + indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(3) + ], + dim=-1, + ) + return fractional_positions + + 
def precompute_freqs_cis(self, indices_grid, spacing="exp"): + dtype = torch.float32 # We need full precision in the freqs_cis computation. + dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + device = fractional_positions.device + if spacing == "exp": + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + dim // 6, + device=device, + dtype=dtype, + ) + ) + indices = indices.to(dtype=dtype) + elif spacing == "exp_2": + indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim) + indices = indices.to(dtype=dtype) + elif spacing == "linear": + indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype) + elif spacing == "sqrt": + indices = torch.linspace( + start**2, end**2, dim // 6, device=device, dtype=dtype + ).sqrt() + + indices = indices * math.pi / 2 + + if spacing == "exp_2": + freqs = ( + (indices * fractional_positions.unsqueeze(-1)) + .transpose(-1, -2) + .flatten(2) + ) + else: + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if dim % 6 != 0: + cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq.to(self.dtype), sin_freq.to(self.dtype) + + def load_state_dict( + self, + state_dict: Dict, + *args, + **kwargs, + ): + if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]): + state_dict = { + key.replace("model.diffusion_model.", ""): value + for key, value in state_dict.items() + if key.startswith("model.diffusion_model.") + } + super().load_state_dict(state_dict, **kwargs) + + @classmethod + def 
from_pretrained(
+        cls,
+        pretrained_model_path: Optional[Union[str, os.PathLike]],
+        *args,
+        **kwargs,
+    ):
+        pretrained_model_path = Path(pretrained_model_path)
+        if pretrained_model_path.is_dir():
+            config_path = pretrained_model_path / "transformer" / "config.json"
+            with open(config_path, "r") as f:
+                config = make_hashable_key(json.load(f))
+
+            assert config in diffusers_and_ours_config_mapping, (
+                "Provided diffusers checkpoint config for transformer is not supported. "
+                "We only support diffusers configs found in Lightricks/LTX-Video."
+            )
+
+            config = diffusers_and_ours_config_mapping[config]
+            state_dict = {}
+            ckpt_paths = (
+                pretrained_model_path
+                / "transformer"
+                / "diffusion_pytorch_model*.safetensors"
+            )
+            dict_list = glob.glob(str(ckpt_paths))
+            for dict_path in dict_list:
+                part_dict = {}
+                with safe_open(dict_path, framework="pt", device="cpu") as f:
+                    for k in f.keys():
+                        part_dict[k] = f.get_tensor(k)
+                state_dict.update(part_dict)
+
+            for key in list(state_dict.keys()):
+                new_key = key
+                for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+                    new_key = new_key.replace(replace_key, rename_key)
+                state_dict[new_key] = state_dict.pop(key)
+
+            with torch.device("meta"):
+                transformer = cls.from_config(config)
+            transformer.load_state_dict(state_dict, assign=True, strict=True)
+        elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
+            ".safetensors"
+        ):
+            comfy_single_file_state_dict = {}
+            with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
+                metadata = f.metadata()
+                for k in f.keys():
+                    comfy_single_file_state_dict[k] = f.get_tensor(k)
+            configs = json.loads(metadata["config"])
+            transformer_config = configs["transformer"]
+            with torch.device("meta"):
+                transformer = cls.from_config(transformer_config)
+            transformer.load_state_dict(comfy_single_file_state_dict, assign=True)
+        return transformer
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        
indices_grid: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. 
Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + skip_layer_mask ( `torch.Tensor`, *optional*): + A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position + `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index. + skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # for tpu attention offload 2d token masks are used. No need to transform. + if not self.use_tpu_flash_attention: + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 
+ # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + hidden_states = self.patchify_proj(hidden_states) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + freqs_cis = self.precompute_freqs_cis(indices_grid) + + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view( + batch_size, -1, embedded_timestep.shape[-1] + ) + + # 2. 
Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view( + batch_size, -1, hidden_states.shape[-1] + ) + + for block_idx, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + freqs_cis, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + ( + skip_layer_mask[block_idx] + if skip_layer_mask is not None + else None + ), + skip_layer_strategy, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + skip_layer_mask=( + skip_layer_mask[block_idx] + if skip_layer_mask is not None + else None + ), + skip_layer_strategy=skip_layer_strategy, + ) + + # 3. 
class Checkpointer:
    """
    Loads and stores JAX train-state checkpoints via an async Orbax CheckpointManager.
    """

    # Keys used when serializing a state dict's structure (shape/dtype per leaf).
    STATE_DICT_SHAPE_KEY = "shape"
    STATE_DICT_DTYPE_KEY = "dtype"
    TRAIN_STATE_FILE_NAME = "train_state"

    def __init__(
        self,
        checkpoint_dir: str,
        use_zarr3: bool = False,
        save_buffer_size: Optional[int] = None,
        restore_buffer_size: Optional[int] = None,
    ):
        """Constructs the checkpointer and its underlying async CheckpointManager."""
        options = ocp.CheckpointManagerOptions(
            enable_async_checkpointing=True,
            step_format_fixed_length=8,  # steps are rendered zero-padded, e.g. "00000000"
        )
        self.use_zarr3 = use_zarr3
        self.save_buffer_size = save_buffer_size
        self.restore_buffer_size = restore_buffer_size
        handler_registry = ocp.DefaultCheckpointHandlerRegistry()
        self.train_state_handler = ocp.PyTreeCheckpointHandler(
            save_concurrent_gb=save_buffer_size,
            restore_concurrent_gb=restore_buffer_size,
            use_zarr3=use_zarr3,
        )
        handler_registry.add(
            self.TRAIN_STATE_FILE_NAME,
            ocp.args.PyTreeSave,
            self.train_state_handler,
        )
        self.manager = ocp.CheckpointManager(
            directory=checkpoint_dir,
            options=options,
            handler_registry=handler_registry,
        )

    @property
    def save_buffer_size_bytes(self) -> Optional[int]:
        """The save buffer size converted from GiB to bytes, or None if unset."""
        return None if self.save_buffer_size is None else self.save_buffer_size * 2**30

    @staticmethod
    def state_dict_to_structure_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
        """
        Maps every jax.Array leaf of `state_dict` to its {shape, dtype} description,
        so the state_dict structure can be reconstructed later without the data.
        """

        def describe(leaf):
            return {
                Checkpointer.STATE_DICT_SHAPE_KEY: tuple(leaf.shape),
                Checkpointer.STATE_DICT_DTYPE_KEY: leaf.dtype.name,
            }

        return jax.tree_util.tree_map(
            describe,
            state_dict,
            is_leaf=lambda leaf: isinstance(leaf, jax.Array),
        )

    def save(
        self,
        step: int,
        state: train_state.TrainState,
        config: Dict[str, Any],
    ):
        """
        Saves the checkpoint asynchronously.

        NOTE that state is going to be copied for this operation.

        Args:
            step (int): The step of the checkpoint
            state: A trainstate containing both the parameters and the optimizer state
            config (Dict[str, Any]): A dictionary containing the configuration of the model
        """
        # Make sure any in-flight async save has finished before starting a new one.
        self.wait()
        composite = ocp.args.Composite(
            train_state=ocp.args.PyTreeSave(
                state,
                ocdbt_target_data_file_size=self.save_buffer_size_bytes,
            ),
            config=ocp.args.JsonSave(config),
            meta_params=ocp.args.JsonSave(self.state_dict_to_structure_dict(state.params)),
        )
        self.manager.save(step, args=composite)

    def wait(self):
        """Blocks until any pending checkpoint save operation completes."""
        self.manager.wait_until_finished()
""" + self.manager.wait_until_finished() + + +""" +Convert Torch checkpoints to JAX. + +This script loads a Torch checkpoint (either regular or sharded), converts it to Jax weights, and saved it. +""" +def main(args): + """ + Convert a Torch checkpoint into JAX. + """ + + if args.output_step_num > 1: + print( + "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between " + "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in " + "training loss when resuming from the converted checkpoint." + ) + + print("Loading safetensors, flush = True") + weight_file = "ltxv-13b-0.9.7-dev.safetensors" + + #download from huggingface, otherwise load from local + if (args.local_ckpt_path is None): + print("Loading from HF", flush = True) + model_name = "Lightricks/LTX-Video" + local_file_path = hf_hub_download( + repo_id=model_name, + filename=weight_file, + local_dir=args.download_ckpt_path, + local_dir_use_symlinks=False, + ) + else: + base_dir = args.local_ckpt_path + local_file_path = os.path.join(base_dir, weight_file) + torch_state_dict = load_file(local_file_path) + + + + print("Initializing pytorch transformer..", flush=True) + transformer_config = json.loads(open(args.transformer_config_path, "r").read()) + ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "ckpt_path"] + for key in ignored_keys: + if key in transformer_config: + del transformer_config[key] + + transformer = Transformer3DModel_PT.from_config(transformer_config) + + + print("Loading torch weights into transformer..", flush=True) + transformer.load_state_dict(torch_state_dict) + torch_state_dict = transformer.state_dict() + + print("Creating jax transformer with params..", flush=True) + transformer_config["use_tpu_flash_attention"] = True + in_channels = transformer_config["in_channels"] + del transformer_config["in_channels"] + jax_transformer3d = 
JaxTranformer3DModel(**transformer_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") + example_inputs = {} + batch_size, num_tokens = 2, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, transformer_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + params_jax = jax_transformer3d.init(jax.random.PRNGKey(42), **example_inputs) + + + print("Converting torch params to jax..", flush=True) + params_jax = torch_statedict_to_jax(params_jax, torch_state_dict) + + print("Creating checkpointer and jax state for saving..", flush=True) + relative_ckpt_path = args.output_dir + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + tx = optax.adamw(learning_rate=1e-5) + with jax.default_device("cpu"): + state = train_state.TrainState( + step=args.output_step_num, + apply_fn=jax_transformer3d.apply, + params=params_jax, + tx=tx, + opt_state=tx.init(params_jax), + ) + with ocp.CheckpointManager(absolute_ckpt_path) as mngr: + mngr.save(args.output_step_num, args = ocp.args.StandardSave(state.params)) + print("Done.", flush=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Torch checkpoints to Jax format.") + parser.add_argument( + "--local_ckpt_path", + type=str, + required=False, + help="Local path of the checkpoint to convert. 
def make_hashable_key(dict_key):
    """Recursively turn a (possibly nested) config dict into a hashable, order-independent key.

    Dicts become tuples of sorted ``(key, value)`` pairs, so two dicts with the
    same contents map to the same key regardless of insertion order; lists become
    tuples. Fix: list elements are now converted recursively as well, so lists
    containing dicts or nested lists (e.g. a VAE ``blocks`` spec) no longer
    produce an unhashable result. Flat lists of scalars behave exactly as before.
    """

    def convert_value(value):
        if isinstance(value, list):
            # Recurse into elements — a bare tuple(value) would keep inner
            # lists/dicts unhashable.
            return tuple(convert_value(item) for item in value)
        elif isinstance(value, dict):
            return tuple(sorted((k, convert_value(v)) for k, v in value.items()))
        else:
            return value

    return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items()))
"base_image_seq_len": 1024, + "base_shift": 0.95, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 2.05, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.1, + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} +DIFFUSERS_TRANSFORMER_CONFIG = { + "_class_name": "LTXVideoTransformer3DModel", + "_diffusers_version": "0.32.0.dev0", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_out_bias": True, + "caption_channels": 4096, + "cross_attention_dim": 2048, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_layers": 28, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm_across_heads", +} +DIFFUSERS_VAE_CONFIG = { + "_class_name": "AutoencoderKLLTXVideo", + "_diffusers_version": "0.32.0.dev0", + "block_out_channels": [128, 256, 512, 512], + "decoder_causal": False, + "encoder_causal": True, + "in_channels": 3, + "latent_channels": 128, + "layers_per_block": [4, 3, 3, 3, 4], + "out_channels": 3, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-06, + "scaling_factor": 1.0, + "spatio_temporal_scaling": [True, True, True, False], +} + +OURS_SCHEDULER_CONFIG = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": "0.25.1", + "num_train_timesteps": 1000, + "shifting": "SD3", + "base_resolution": None, + "target_shift_terminal": 0.1, +} + +OURS_TRANSFORMER_CONFIG = { + "_class_name": "Transformer3DModel", + "_diffusers_version": "0.25.1", + "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 2048, + "double_self_attention": False, + "dropout": 0.0, + "in_channels": 128, + "norm_elementwise_affine": False, + 
"norm_eps": 1e-06, + "norm_num_groups": 32, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 28, + "num_vector_embeds": None, + "only_cross_attention": False, + "out_channels": 128, + "project_to_2d_pos": True, + "upcast_attention": False, + "use_linear_projection": False, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, +} +OURS_VAE_CONFIG = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, +} + + +diffusers_and_ours_config_mapping = { + make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, + make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, + make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, +} + + +TRANSFORMER_KEYS_RENAME_DICT = { + "proj_in": "patchify_proj", + "time_embed": "adaln_single", + "norm_q": "q_norm", + "norm_k": "k_norm", +} + + +VAE_KEYS_RENAME_DICT = { + "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", + "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", + "decoder.up_blocks.3": "decoder.up_blocks.9", + "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", + "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", + "decoder.up_blocks.2": "decoder.up_blocks.6", + "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", + "decoder.up_blocks.1": "decoder.up_blocks.3", + "decoder.up_blocks.0": "decoder.up_blocks.1", + "decoder.mid_block": 
"decoder.up_blocks.0", + "encoder.down_blocks.3": "encoder.down_blocks.8", + "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", + "encoder.down_blocks.2": "encoder.down_blocks.6", + "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", + "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", + "encoder.down_blocks.1": "encoder.down_blocks.3", + "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", + "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", + "encoder.down_blocks.0": "encoder.down_blocks.0", + "encoder.mid_block": "encoder.down_blocks.9", + "conv_shortcut.conv": "conv_shortcut", + "resnets": "res_blocks", + "norm3": "norm3.norm", + "latents_mean": "per_channel_statistics.mean-of-means", + "latents_std": "per_channel_statistics.std-of-means", +} \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py new file mode 100644 index 000000000..c34df51ff --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py @@ -0,0 +1,8 @@ +from enum import Enum, auto + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_compat.py b/src/maxdiffusion/models/ltx_video/utils/torch_compat.py new file mode 100644 index 000000000..450cf9b19 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/torch_compat.py @@ -0,0 +1,519 @@ +import re +from copy import copy +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Union, Any + +import flax +import jax +import torch +import torch.utils._pytree as pytree +from flax.traverse_util import flatten_dict + + + +AnyTensor = Union[jax.Array, torch.Tensor] +StateDict = Dict[str, AnyTensor] + +ScanRepeatableCarryBlock = 
"ScanRepeatableCarryBlock" + +JaxParams = Dict[str, Union[Dict[str, jax.Array], jax.Array]] + +def unbox_logically_partioned(statedict: JaxParams) -> JaxParams: + return jax.tree_util.tree_map( + lambda t: t.unbox() if isinstance(t, flax.linen.spmd.LogicallyPartitioned) else t, + statedict, + is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned), + ) + +def torch_tensor_to_jax_array(data: torch.Tensor) -> jax.Array: + match data.dtype: + case torch.bfloat16: + return jax.numpy.from_dlpack(data) + case _: + return jax.numpy.array(data) + + +def is_stack_or_tensor(param: Any) -> bool: + """ + Returns True if param is of type tensor or list/tuple of tensors (stack of tensors) + + Used for mapping utils + """ + return isinstance(param, (torch.Tensor, list, tuple)) + + +def convert_tensor_stack_to_tensor(param: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: + """ + Converts a list of torch tensors to a single torch tensor. + Args: + param (Union[List[torch.Tensor], torch.Tensor]): The parameter to convert. + + Returns: + torch.Tensor: The converted tensor. + """ + if isinstance(param, list): + return torch.stack(param) + return param + + +@dataclass +class ConvertAction: + """ + Defines a set of actions to be done on a given parameter. + + The definition must be commutative, i.e. the order of the actions should not matter. + also we should strive for actions to be reversible (so the same action can be used for both directions). + """ + + transpose: Optional[Tuple[int, int]] = None + """ + If defined, transposes the tensor with the given indices. + Example: (1, 0) transposes a (at least 2D tensor) from (..., a, b) to (..., b, a). + """ + + rename: Optional[Dict[str, str]] = None + """ + If defined, renames the parameter according to the given mapping. + Example: {"torch": "weight", "jax": "kernel"} + * renames "torch.weight" to "jax.kernel" when converting from torch to jax. 
+ * renames "jax.kernel" to "torch.weight" when converting from jax to torch. + """ + + split_by: Optional[str] = None + """ + If defined, splits the parameter by the given delimiter. + Example: "ScanRepeatableCarryBlock.k1" assumes the parameter is a concatenation of multiple tensors (shaped: (n, ...)). + and splits them into individual tensors named as "ScanRepeatableCarryBlock.0.k1", "ScanRepeatableCarryBlock.n.k1". + """ + + group_by: Optional[str] = None + """ + If defined, groups the parameter by the given delimiter. + Example: "ScanRepeatableCarryBlock.0.k1", "ScanRepeatableCarryBlock.1.k1", "ScanRepeatableCarryBlock.2.k1" + will be grouped into a single tensor named "ScanRepeatableCarryBlock.k1" shaped (n, ...). + + *** Note: + this is kind of the reverse of split_by, only a different behavior. + it's easy to define "actions" that are reversible in base of context (jax->torch, torch->jax). + but it's very wrong to do so, since it blocks modular behavior and makes the code harder to maintain. + + """ + + jax_groups: Optional[List[str]] = None + """ + Generally used in group_by, this is a list of all possible keys that can be used to group the parameters. + This must be defined if group_by is defined. + + It's due to the un-reversibility nature of the group_by action. 
+ """ + + def apply_transpose(self, mini_statedict: StateDict) -> StateDict: + """ + Applies the transpose action if defined + Args: + mini_statedict (StateDict): Local context of the state dict + + Returns: + StateDict: Output local context of the state dict + """ + + if self.transpose is None: + return mini_statedict + index0, index1 = self.transpose + return {param_name: param.swapaxes(index0, index1) for param_name, param in mini_statedict.items()} + + def apply_rename(self, mini_statedict: StateDict, delim: str) -> StateDict: + """ + Applies the rename action if defined + + Args: + mini_statedict (StateDict): Local context of the state dict + delim (str): delimiter used for parsing (usually "."), kept as parameter for flexibility. + + Returns: + StateDict: Output local context of the state dict + """ + if self.rename is None: + return mini_statedict + + param_names = list(mini_statedict.keys()) + for param_name in param_names: + param = mini_statedict.pop(param_name) + parts = param_name.split(delim) + rename_source = "torch" if isinstance(param, torch.Tensor) else "jax" + rename_target = "jax" if isinstance(param, torch.Tensor) else "torch" + source_name = self.rename[rename_source] + dest_name = self.rename[rename_target] + if source_name == param_name: + new_param_name = dest_name + else: + # There is always ```self.rename[rename_source]``` in parts + index = parts.index(self.rename[rename_source]) + parts[index] = self.rename[rename_target] + new_param_name = delim.join(parts) + mini_statedict[new_param_name] = param + + return mini_statedict + + def apply_split_by(self, mini_statedict: StateDict, new_params: List, delim: str) -> Tuple[StateDict, List[str]]: + """ + Applies the split_by action if defined + + Args: + mini_statedict (StateDict): Local state dict + new_params (List): State containing list of new params that were created during the process (if any) + delim (str): Output local context of the state dict + + Returns: + Tuple[StateDict, 
List[str]]: Output local context of the state dict and list of new keys to add to the global state dict. + """ + if self.split_by is None: + return mini_statedict, new_params + + param_names = list(mini_statedict.keys()) + for param_name in param_names: + parts = param_name.split(delim) + indices = [i for i, p in enumerate(parts) if self.split_by in p] + if len(indices) != 1: + raise ValueError(f"Expected exactly one split_by in param_name: {param_name}") + index = indices[0] + params = mini_statedict.pop(param_name) + for i, param in enumerate(params): + new_parts = parts[:index] + [f"{i}"] + parts[index + 2 :] + new_param_name = delim.join(new_parts) + mini_statedict[new_param_name] = param + new_params.append(new_param_name) + + return mini_statedict, new_params + + def apply_group_by( + self, mini_statedict: StateDict, new_params: List, full_statedict: StateDict, delim: str + ) -> Tuple[StateDict, List[str]]: + """ + Applies the group_by action if defined + + Args: + mini_statedict (StateDict): Local state dict + new_params (List): State containing list of new params that were created during the process (if any) + full_statedict (StateDict): Global context of the state dict + delim (str): delimiter used for parsing (usually "."), kept as parameter for flexibility. + + Returns: + Tuple[StateDict, List[str]]: Output local context of the state dict and list of new keys to add to the global state dict. 
+ """ + if self.group_by is None: + return mini_statedict, new_params + + param_names = list(mini_statedict.keys()) + for param_name in param_names: + param = mini_statedict.pop(param_name) + jax_keywords = extract_scan_keywords(param_name, self.jax_groups, delim) + block_index = re.findall(r"\.\d+\.", param_name)[0][1:-1] + parts = param_name.split(delim) + index = parts.index(block_index) + prefix = delim.join(parts[:index]) + suffix = delim.join(parts[index + 1 :]) + + new_param_name = f"{prefix}.{delim.join(jax_keywords)}.{suffix}" + + if new_param_name not in full_statedict: + full_statedict[new_param_name] = [param] + else: + full_statedict[new_param_name] = full_statedict[new_param_name] + [param] + + return mini_statedict, new_params + + def __call__( + self, + mini_statedict: StateDict, + new_params: List, + full_statedict: StateDict, + delim: str, + ) -> Tuple[StateDict, List[str]]: + """ + Given a state dict, applies the transformations defined in the ConvertAction. + + Args: + mini_statedict (StateDict): Local context of the state dict + new_params (List): new params that were created during the process (if any) + full_statedict (StateDict): Global context of the state dict + delim (str): delimiter used for parsing (usually "."), kept as parameter for flexibility. + + Returns: + Tuple[StateDict, List[str]]: Updated local state dict and list of new keys to add to the global state dict. + """ + mini_statedict = self.apply_transpose(mini_statedict) + mini_statedict = self.apply_rename(mini_statedict, delim) + mini_statedict, new_params = self.apply_split_by(mini_statedict, new_params, delim) + mini_statedict, new_params = self.apply_group_by(mini_statedict, new_params, full_statedict, delim) + return mini_statedict, new_params + + +def is_kernel_2d(param_name: str, param: AnyTensor) -> bool: + """ + Checks if the parameter is a 2D kernel (weight) or not. + usually applies to linear layers or convolutions. 
def is_scan_repeatable(param_name: str, _) -> bool:
    """
    Returns True when the parameter name belongs to a scan-repeatable carry block.

    Args:
        param_name (str): Parameter name
        _: Unused; receives the parameter value (kept for predicate-signature uniformity)
    """
    return ScanRepeatableCarryBlock in param_name


def is_scale_shift_table(param_name: str, _) -> bool:
    """
    Returns True when the parameter name refers to a scale-shift table.

    Args:
        param_name (str): Parameter name
        _: Unused; receives the parameter value (kept for predicate-signature uniformity)
    """
    return "scale_shift_table" in param_name
+ """ + if isinstance(parameter, torch.Tensor): + return "weight" in param_name and parameter.ndim == 1 and param_name not in jax_flattened_keys + else: + return "scale" in param_name and parameter.ndim == 1 + + +def extract_scan_keywords(param_name: str, jax_flattened_keys: List[str], delim: str) -> Optional[Tuple[str, str]]: + """ + Extracts the keywords from the scan repeatable carry block parameter (if exists) + + If the parameter is a scan repeatable carry block, it will return the keywords that are used to group the parameters. + otherwise it will return None. + + Args: + param_name (str): Name of the parameter + jax_flattened_keys (List[str]): Flattened list of the keys use in jax (for reference and keys search) + delim (str): The delimiter used in the parameter name (in torch) + + Returns: + Optional[Tuple[str, str]]: Tuple of the keywords used to group the parameters (or None if it is not a scan repeatable carry block) + """ + block_indices = re.findall(r"\.\d+\.", param_name) + + if len(block_indices) == 0: + return None + block_indices = [block_indices[0]] + block_index = block_indices[0][1:-1] + parts = param_name.split(delim) + index = parts.index(block_index) + prefix = delim.join(parts[:index]) + suffix = delim.join(parts[index + 1 :]) + + for flat_key in jax_flattened_keys: + if flat_key.startswith(prefix) and flat_key.endswith(suffix): + mid_layer = flat_key[len(prefix) + 1 : -len(suffix) - 1] + mid_parts = mid_layer.split(delim) + if not any(ScanRepeatableCarryBlock in mid_part for mid_part in mid_parts): + continue + return mid_parts + + return None + + +def should_be_scan_repeatable(param_name: str, param: AnyTensor, jax_flattened_keys: List[str], delim: str) -> bool: + """ + Checks if the parameter should be a scan repeatable carry block or not. 
+ Args: + param_name (str): The name of the parameter + param (AnyTensor): the Parameter itself + jax_flattened_keys (List[str]): Flattened list of the keys use in jax (for reference and keys search) + delim (str): The delimiter used in the parameter name (in torch) + + Returns: + bool: True if the paramter should be treated scan repeatable block parameter. + """ + if not isinstance(param, torch.Tensor): + return False + + keywords = extract_scan_keywords(param_name, jax_flattened_keys, delim) + return keywords is not None + + +def jax_statedict_to_torch( + jax_params: JaxParams, rulebook: Optional[Dict[Callable[[str, AnyTensor], bool], ConvertAction]] = None +) -> Dict[str, torch.Tensor]: + """ + Converts a JAX state dict to a torch state dict. + + Args: + jax_params (JaxParams): The current params in JAX format, to ease parsing and conversion. + rulebook (Optional[Dict[Callable[[str, AnyTensor], bool], ConvertAction]], optional): Defines a rulebook stating how to convert state dict from jax to torch. + Defaults to None. + + + Returns: + Dict[str, torch.Tensor]: The converted state dict in torch format (Pytorch state dict). + """ + + affine_scale_search = partial(is_affine_scale_param, jax_flattened_keys=[]) + + if rulebook is None: + rulebook = { + is_scan_repeatable: ConvertAction(split_by=ScanRepeatableCarryBlock), + is_kernel_2d: ConvertAction(transpose=(1, 0), rename=dict(torch="weight", jax="kernel")), + affine_scale_search: ConvertAction(rename=dict(torch="weight", jax="scale")), + } + if "params" not in jax_params: + raise ValueError('Expected "params" key in jax_params, are you sure you are passing the correct object?') + + jax_params = copy(jax_params["params"]) # Non reference copy + jax_params = unbox_logically_partioned(jax_params) + + delim = "." 
+ # Move to flattened dict to match torch state dict convention + flattened_params = flatten_dict(jax_params, sep=delim) + + param_names = list(flattened_params.keys()) + for param_name in param_names: + param = flattened_params.pop(param_name) + mini_statedict = {param_name: param} + new_params = [] + for condition, rule in rulebook.items(): + if condition(param_name, param): + mini_statedict, new_params = rule(mini_statedict, new_params, flattened_params, delim) + if len(mini_statedict) == 1: + param_name = list(mini_statedict.keys())[0] + + flattened_params.update(mini_statedict) + param_names.extend(new_params) + + flattened_params = pytree.tree_map(convert_tensor_stack_to_tensor, flattened_params, is_leaf=is_stack_or_tensor) + + to_cpu = pytree.tree_map(lambda t: jax.device_put(t, jax.devices("cpu")[0]), flattened_params) + to_torch = pytree.tree_map(torch.from_dlpack, to_cpu) + return to_torch + + +def torch_statedict_to_jax( + jax_params: JaxParams, + torch_params: Dict[str, torch.Tensor], +) -> JaxParams: + """ + Converts a torch state dict to a JAX state dict. + + Args: + jax_params (JaxParams): The current params in JAX format, to ease parsing and conversion. + torch_params (Dict[str, torch.Tensor]): The current params in torch format, to load parameters from. + + Returns: + JaxParams: The state dict in JAX format. + """ + with jax.default_device("cpu"): + jax_params = copy(jax_params) + jax_params = unbox_logically_partioned(jax_params) + torch_params = copy(torch_params) + + if "params" not in jax_params: + raise ValueError('Expected "params" key in jax_params, are you sure you are passing the correct object?') + + delim = "." 
+ flattened_keys = list(flatten_dict(jax_params["params"], sep=".").keys()) + scan_repeatable_cond = partial(should_be_scan_repeatable, jax_flattened_keys=flattened_keys, delim=delim) + affine_scale_search = partial(is_affine_scale_param, jax_flattened_keys=flattened_keys) + + rulebook = { + is_kernel_2d: ConvertAction(transpose=(1, 0), rename=dict(torch="weight", jax="kernel")), + affine_scale_search: ConvertAction(rename=dict(torch="weight", jax="scale")), + scan_repeatable_cond: ConvertAction(group_by=ScanRepeatableCarryBlock, jax_groups=flattened_keys), + } + + # First pass - Rulebook + param_names = list(torch_params.keys()) + for param_name in param_names: + param = torch_params.pop(param_name) + mini_statedict = {param_name: param} + new_params = [] + for condition, rule in rulebook.items(): + if condition(param_name, param): + mini_statedict, new_params = rule(mini_statedict, new_params, torch_params, delim=delim) + if len(mini_statedict) == 1: + param_name = list(mini_statedict.keys())[0] + + torch_params.update(mini_statedict) + param_names.extend(new_params) + + # Ensures any list of tensors are converted to a single tensor + # This is due to the fact that the scan repeatable block is a list of tensors + torch_params = pytree.tree_map(convert_tensor_stack_to_tensor, torch_params, is_leaf=is_stack_or_tensor) + + to_jax: Dict = pytree.tree_map(torch_tensor_to_jax_array, torch_params) + + def nested_insert(param_name: str, param: torch.Tensor, nested_dict: Dict): + """ + Inserts a parameter into a nested dictionary. (to fit Jax format) + The keys in torch are split into groups by a delimiter of choice (usually "." to fit torch schema) + and then inserted into a nested dictionary. + + in case the parameter is of the form of "a.b" and "a.b" is a layer type in jax - + the parameter will be inserted as "a.b": {...: param}. this ensures compatibility between jax layers and torch layers. 
+ + Args: + param_name (str): Parameter name + param (torch.Tensor): Parameter itself + nested_dict (Dict): Current nested dict state + """ + if delim not in param_name: + nested_dict[param_name] = param + return + + parts = param_name.split(delim) + if len(parts) == 1: + return nested_insert(parts[0], param, nested_dict) + else: + key = parts[0] + # May be either complex key or nested key + if len(parts) > 2 and re.fullmatch(r"\d+", parts[1]) is not None: + key = delim.join(parts[:2]) + new_param_name = delim.join(parts[2:]) + else: + new_param_name = delim.join(parts[1:]) + new_nested_dict = nested_dict.get(key, {}) + nested_dict[key] = new_nested_dict + return nested_insert(new_param_name, param, new_nested_dict) + + params = {} + for param_name, param in to_jax.items(): + nested_insert(param_name, param, params) + + # Jax state dict is usually held as dict containings "parmas" keys which contains + # dict of dict containing all the params + return {"params": params} diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json index 770093859..10414eabd 100644 --- a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -1,4 +1,5 @@ { + "ckpt_path": "/mnt/disks/diffusionproj/jax_weights", "activation_fn": "gelu-approximate", "attention_bias": true, "attention_head_dim": 128, @@ -20,5 +21,6 @@ "positional_embedding_type": "rope", "positional_embedding_theta": 10000.0, "positional_embedding_max_pos": [20, 2048, 2048], - "timestep_scale_multiplier": 1000 + "timestep_scale_multiplier": 1000, + "in_channels": 128 } diff --git a/src/maxdiffusion/tests/ltx_video_transformer_test.py b/src/maxdiffusion/tests/ltx_video_transformer_test.py new file mode 100644 index 000000000..760eb4b2c --- /dev/null +++ b/src/maxdiffusion/tests/ltx_video_transformer_test.py @@ -0,0 +1,306 @@ +""" + Copyright 2025 Google 
LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import os +import jax +import jax.numpy as jnp +import unittest +from absl.testing import absltest +from flax import nnx +from jax.sharding import Mesh + +from .. import pyconfig +from ..max_utils import ( + create_device_mesh, + get_flash_block_sizes +) +from ..models.ltx_video.transformers.transformer3d import ( + reqsCisPrecomputer, Transformer3DModel +) +from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection +from ..models.normalization_flax import FP32LayerNorm +from ..models.ltx_video.transformers.attention import Attention, BasicTransformerBlock, apply_rotary_emb + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + +class LTXVideoTransformerTest(unittest.TestCase): + def setUp(self): + LTXVideoTransformerTest.dummy_data = {} + + def test_rotary_pos_embed(self): + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_states_shape = (batch_size, frames, height, width, channels) + dummy_hidden_states = jnp.ones(hidden_states_shape) + wan_rot_embed = WanRotaryPosEmbed( + attention_head_dim=128, + patch_size=[1, 2, 2], + max_seq_len=1024 + ) + dummy_output = wan_rot_embed(dummy_hidden_states) + assert dummy_output.shape == (1, 1, 75600, 64) + + def test_nnx_pixart_alpha_text_projection(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + dummy_caption = jnp.ones((1, 512, 4096)) + layer = NNXPixArtAlphaTextProjection( + rngs=rngs, + 
in_features=4096, + hidden_size=5120 + ) + dummy_output = layer(dummy_caption) + dummy_output.shape == (1, 512, 5120) + + def test_nnx_timestep_embedding(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + + dummy_sample = jnp.ones((1, 256)) + layer = NNXTimestepEmbedding( + rngs=rngs, + in_channels=256, + time_embed_dim=5120 + ) + dummy_output = layer(dummy_sample) + assert dummy_output.shape == (1, 5120) + + def test_fp32_layer_norm(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + batch_size = 1 + dummy_hidden_states = jnp.ones((batch_size, 75600, 5120)) + # expected same output shape with same dtype + layer = FP32LayerNorm( + rngs=rngs, + dim=5120, + eps=1e-6, + elementwise_affine=False + ) + dummy_output = layer(dummy_hidden_states) + assert dummy_output.shape == dummy_hidden_states.shape + + def test_wan_time_text_embedding(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + batch_size = 1 + dim=5120 + time_freq_dim=256 + time_proj_dim=30720 + text_embed_dim=4096 + layer = WanTimeTextImageEmbedding( + rngs=rngs, + dim=dim, + time_freq_dim=time_freq_dim, + time_proj_dim=time_proj_dim, + text_embed_dim=text_embed_dim + ) + + dummy_timestep = jnp.ones(batch_size) + + encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim) + dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape) + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(dummy_timestep, dummy_encoder_hidden_states) + assert temb.shape == (batch_size, dim) + assert timestep_proj.shape == (batch_size, time_proj_dim) + assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim) + + def test_wan_block(self): + key = jax.random.key(0) + rngs = nnx.Rngs(key) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True + ) + config = pyconfig.config + + devices_array = create_device_mesh(config) + + flash_block_sizes = get_flash_block_sizes(config) + + 
mesh = Mesh(devices_array, config.mesh_axes) + + dim=5120 + ffn_dim=13824 + num_heads=40 + qk_norm="rms_norm_across_heads" + cross_attn_norm=True + eps=1e-6 + + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_dim = 75600 + + # for rotary post embed. + hidden_states_shape = (batch_size, frames, height, width, channels) + dummy_hidden_states = jnp.ones(hidden_states_shape) + + wan_rot_embed = WanRotaryPosEmbed( + attention_head_dim=128, + patch_size=[1, 2, 2], + max_seq_len=1024 + ) + dummy_rotary_emb = wan_rot_embed(dummy_hidden_states) + assert dummy_rotary_emb.shape == (batch_size, 1, hidden_dim, 64) + + # for transformer block + dummy_hidden_states = jnp.ones((batch_size, hidden_dim, dim)) + + dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim)) + + dummy_temb = jnp.ones((batch_size, 6, dim)) + + wan_block = WanTransformerBlock( + rngs=rngs, + dim=dim, + ffn_dim=ffn_dim, + num_heads=num_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + attention="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes + ) + + dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb) + assert dummy_output.shape == dummy_hidden_states.shape + + + + def test_wan_attention(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True + ) + config = pyconfig.config + + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_states_shape = (batch_size, frames, height, width, channels) + dummy_hidden_states = jnp.ones(hidden_states_shape) + wan_rot_embed = WanRotaryPosEmbed( + attention_head_dim=128, + patch_size=[1, 2, 2], + max_seq_len=1024 + ) + dummy_rotary_emb = wan_rot_embed(dummy_hidden_states) + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + devices_array = create_device_mesh(config) + + flash_block_sizes = get_flash_block_sizes(config) + + mesh = Mesh(devices_array, 
config.mesh_axes) + batch_size = 1 + query_dim = 5120 + attention = FlaxWanAttention( + rngs=rngs, + query_dim=query_dim, + heads=40, + dim_head=128, + attention_kernel="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + + dummy_hidden_states_shape = (batch_size, 75600, query_dim) + + dummy_hidden_states = jnp.ones(dummy_hidden_states_shape) + dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape) + + dummy_output = attention( + hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb + ) + assert dummy_output.shape == dummy_hidden_states_shape + + # dot product + try: + attention = FlaxWanAttention( + rngs=rngs, + query_dim=query_dim, + heads=40, + dim_head=128, + attention_kernel="dot_product", + split_head_dim=True, + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + except NotImplementedError as e: + pass + + def test_wan_model(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True + ) + config = pyconfig.config + + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_states_shape = (batch_size, frames, height, width, channels) + dummy_hidden_states = jnp.ones(hidden_states_shape) + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + devices_array = create_device_mesh(config) + + flash_block_sizes = get_flash_block_sizes(config) + + mesh = Mesh(devices_array, config.mesh_axes) + batch_size = 1 + query_dim = 5120 + wan_model = WanModel( + rngs=rngs, + attention="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + + dummy_timestep = jnp.ones((batch_size)) + dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) + + dummy_output = wan_model( + hidden_states=dummy_hidden_states, + timestep=dummy_timestep, + encoder_hidden_states=dummy_encoder_hidden_states + ) + assert dummy_output.shape == hidden_states_shape + +if __name__ == "__main__": + absltest.main() 
\ No newline at end of file diff --git a/test_lightning.png b/test_lightning.png deleted file mode 100644 index 36e844cc573bbb6e3f89cf1dfc77ec19cb4f2f09..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1343016 zcmV)DK*7I>P)8%97jY(qN>|97c9-o&g}pH!_vO&5>3zDy)=!x3uYyV2zOIu-UkLE z<`LO_jWWaC&D2x{FgOf_Z{PeDgxL->0wl}~06+tX@rri=!tCl3|B&=+c>PTSNWuvC zjY$|l5dIG1zBD6YMlwE)PyCO%BLM@azsBQ12x!Lml8L|T&%3%y{z@?Bi+rYy@l8)O zp9Db|VK4$PG6!j9{^m4X^JmxniF_8X|Frp8Fq#1}W<{GeK!Cib`C7*aU8oUY^1T^Mq6))$gd>5QM2OkZ7_9Y?Bl0B~nnQINQ2-a$mCi3>ZqRSI; z`Exx(Qc5C`fGHssa-*K>9Bp(~cSZmx0vn$TC4?Tq5O=I!08ogcqFm|tHyAXiHRM2?#paio(`w2h^N;-QLDU-S44ro0DWA}ipgPlp z#djxawUmi;+9)~Z^0jDWrmIP0PHgen`(JqF1acvLbb``VHdfk+V-Wk*c1=T?nSx1E zCI*m0ZXoY{^_BQrg0D?|&vSzfSM$mhxc<#1eL7Jf>p7uXT)Ga;LpZaMFcZ3Xq%jqk z#8gUO4C>59@)?K_C}KAWg_U{)A(Np7bM7-3)C{R{f~cEKa$}-=QlKbGvl8jKy016I z#~MVzRIMoyV5u{OFxAp~0EAG3_@{Yy7g?c?EEQ=HN9|)Qfu*pe!DRZ(#*kSY2SUvV z>^mUho2_OPSm?gN*Qzj3UG4ajOq`(&P?!OxP$a{efu>rEmi+F_?^@PU8;x)2&{_Uu zV?MzpLryB?izR8KVy+bp$sG3?ZyHPtmgEIOYH^YZY|<{AFiYLvrr%0Dm)yG2T3;ZO z@yen7n&cQmk7XY?@A{H@;+LPrU%Bpcef`O5Gxdd^Ty(>rEETUl(;JmsrjQeY5+Lo9`v4bqF}mlDW5va6xe zf`LM#4(XBzZB`yu5}ea>&87!k@6q&|vqrA^j6cbIkj~e~4aT$K@bPK}zHXL}m%8Z{ zorJ(M1FgXN)k?|LpT-~_<{EN-aH*-+>uA)7O3rlDCY_Hgou@W{@niXZ*xTj0txV8F% zTW!OjJS%N&mW-9!eh?3@ER9(xvgB@Q4AgUcxbp<2>0D9672RTHnOEdGNz^O4n?z5E zjaM}Pt8((opAx0ZWRu^Gf5veZcAs66v<1%f`l32BIRtB_x z!u*kVDP_|j&-@EH3{t@9DI^mlrPPVrC6uc1BcZ<(N#}%s!B3$+CvrA(-ZQ!~CrgxB zpuUpi1hlknCRU3xfjW07WAZZoZ4;Aq z7b2wYl_8hwQ>;2f5K=?DnrX=xUcc%4N&;f;w(%Pcx^urwB#k|T_bs@CH1ZVy8>mP1@|_!}6^ku2iK*v7JJ>U~QO*tEg9fJuTgJeYg~ibV6d zM4Ia(lb0+OePZ-!d96MgyqPyi5fl)8ayNuGFFaYkzXEqvh9rvFrRtBBE%K`G>X@*_ zouNfDmK10C?;sr1hxOf5KsLXLFZWMi^*c-^Q~q^O(|q!>t5?V zn{HLs`D-WjJPMyl!GKBSIqz^{vg_AmeI)v|WP%bNvxxfQD8{3%Z~oaY$X%Fq*!4F^ zly{{jHe5EdPtKWnj6CnjN#le?36dSAM@jb0MZCHHW216g|r`h zI6{gzK98N1^nWx8XoBIyYou`K4sq}z#Gc<}Tro7L(V&j?htf?yjK!AfzN-OZ5D}27 zCJ?nP^={a>7P>@I!p6Ef9FfVP?7Ahb(Xt6$g|OZ)x^jgl@C9^#p&;~SuGm27fPsQ^ 
za>H4pG&B-2mV7LZtu2kaB*si4upvH*Y(Mi0enDyt6l4i7IG*d|?SJ^jo4va2ufDE@ zpWjz)5U11HYy+lvh&N(6YA;~vdn8n2#b#8)l(5rEknkW=T8vQNF=&vJ4PNu-kZtfW z6Av(faC)6ZSi9h|C#lnrpg|}IHl$`~Q790C)<}ZBCgMCHG&6ComU<{mmA)Db7ln(v z@->cyi_wTd2UG!I5sy$)F#%j!ggnsnD}ykhs%3BsE)E)@=8ju}X1fIzbERSmK$=(* zt~1xX4De*wn1Qngi2`pA{Ob4A;t)Th2yiJ z*DJunl;AqU=~93g$|=o^98l5+gVH3uyk-^@HC_cI9lyH&GiY^2Au%LJ)CkcI{Zw2*pq4L||Z zxmw}xHDIshqU$CQ8eVHSw~YNVb$3!NI6;M~S?nSu3;~+y03W_?g0i?FvDInmh_8vF z2x9CrGPZ3MP4~;3tR=;`6NS0`10dz+Sk(BFnF)Y|2AcDv!jZKrm@@o)oYN>njhIJH zhGha=3Dqk)pMh@7p5jKn7Jt@Shha~81OOcGVeA2rSZIoa1MraMq{l+<#GTD5Dm&OS zSIJT~hdw6(2ifai_#=3|w|W*A)f=5;ocK&@d2W`u3+bYd#Votq*7tUXK$4VGg7 zT2jI%?%~1|sFP1K={BF5*hNz7DV=HS-VDJn498n)5zkhjxX?0lG~?szO$eifg!TAu zGoerHH2ycLRCrZiWW~+hh83Nl-AWWUHFl`(|XqwC~8)5FhbQRggV!yj8@as zha#c?O_4T(o;EX{>jDe+i? zGm(#8Va=~s3M*;KIRpkJMNLUVBX+7)-dl(P_pzFwHqc;7qYF3+OcQ^*x~F>(EO*E4 zaF|3IkOxXw)->?LSu}1zx+KdTe9FC>TjRBu^C3YnLn1S}nky=b{0F=9!&^-B4GD1Ux zGQun|Bge;Pnh>KTiPBe^EVc;sXy%eJY)IeK6gN!SCcXitofM&(I;9b!rmZD{OXH-0 z^z4@)zCbrk1WlbTe3l}f_aYDlxML6G>#J@OL{zz&0MW4Th$Nz5rBDPzj|aKn5fRmhOo4811WGWZjbUk1qMTU8yUwSAGDimpvXn33PA@WZkt<<-hrb@?Xy4;- z`CwBYMZQwia1c?!xP;?A5Xn(EX?T8MtjU0TMTWyqMWY53(MbZQYDKa_B8wp`%@mcc zU6fqj%y}B|vliGgaYZK1iz@>Kzv8agW2W?}2!(`#Me5)fKbW~swxq?bp9T%HlnUfPO=It`BTLos#0NkOPQVo35Jk!#T<(= zJYh(~&e!r=&E0rr;toCr@=R9=gmh`-@nUovE(rR-FLX#1oOx;rS2;^p8@dD#HNkWzH zPEx1_Ge=Xdeluj}%1J_GGpUiP8n_|JW)iA8)YAYHsrzv~F3SOPj5CSW1-Sy**AfzKH1oLS9KrwMZw2N>dAZ9`g+_QRqbh7|pN?dT>LSZYlFP zD{EhGHLBUlDg>J5MF!>skpTn~m<&x38Eyqh-JJx1mkB8LCZcK$+Gu2%+K_~2ED4nb z1Wsxq5J-sqynEm}Qd6<`Vo--)Q{7ol?52?vXQenaRB!4w$KvQwOXV>hW_u}TPU}gt-2u13$^#2v#Aaen>Vm}Gc%;{ZA|WvaEarQTSJ8Y@31jpCGlG%{;5D9nBYn8;Tb-0C zLv>#%#fdas5AHN6Ul39u8P#Bo&6y#GKN8zf)N8`xWwnEjf12fPM?kf zU<=&>vnB*Wl!-X=S);g`2-Y9~Z4?nv+Zg_kSRAKy2^4HLT=E1mj;DK*2GCR373u+n za#+Z;*^Igqme@(Jlbeyg{^4jx+MH%0J{JH`hC&?lkje}2`&($+A$So+cKE-`XXlj! 
z0f+;?_^v7ZjS9E=4O|b45-=4|H5~6`W-n+sm{9U9<$ER28$6|7itOD%5K+o0<_tkX z`H{ND5TFWm(T0Sqn(!`2$bt|fPa?FV8B93PPT=Gs`}6_j(s?-{cckhbIwGGjc(157 z*PrPYNMeRTec#>eD?bh+s8Ur%L}r=>n#&J{iXYDVpei#>sF;*h!*Xv$fVeSY<|Bs` zO9VU5&x5hzxf!{#`H;LpzEXd?dI))wX%4iJJ(;?1lI_UgQd(|SX+#cP4AyhYb6v_% zx;7jWhsOULzU|tg++!;do~C>ZFc$;m0&=K|TI9sLKBl2yrm47d%pXpQWGYKVu8fhL zsCmOAtFghgRIwq0s>0ne?PD0$V<2f)`Ln1=ny7uKIhc9C*SWVSFKHu9L2I1SOn;Iw%SdYFY0hfWXw*-AIvebJBgBLQt#)H^DXbS61c2h{?I(UXsL0orX zNf^2YQ1>x+%V@FX);WEO`1|k(Kqd>Akm^|w6g8D7?glZO+`)<=VG(k_oa;$2Bn=Fn zWAIcd2n$RJHQ!knl#s+5MY4gwuHIv0pKDkt;gDqRM96v=%F84v&I54D&|XfcOG5W)^R8;ooXq+Xbz+2)Ypu^D$1`?j=K0Xo|1g-c))=%0$qh{b&+O?ZA!(` zj7q|EH$$VlRHt=h_&a@#BoZ!^0b1PnaCfOQ4&jOS0&^AvK571PgGA$QnrH#kp#}nFq3AU^DwxBPhFDIRx(dFsG9?1W{Zd^z z%7tV~27$~h3BNH!4Le-Os}zaH3G1Sa7>d9k(jI>6i$f6O= zY>yNmhGUreR1AqDQW|9Qz4r8@L-PfHlG4>4iLqMjVukaojUZ5hfKzO5+q<*GA8Llh~ z4|7EEq;8Q^Z$tzsQ|C)KysX6FSmDXlFc?3cJ7x9FntKuu>CI%p#aPCkjC_@L5LZlj zk!1nbY<;h4;yK76T?3sJFXWE;U%m??l4?aXOty^e$YqQvGj`2QPltb@b^w%T(FIiu zv2x8=DkODY+yo)4$`lyOofKfP@m2<|!39)qR%FuG4!F_|+RFYRQFK3cCq|MPVNBA` z_a{avBR+JcyKw^mBl3LU8VHDNV;6!JPDcaOBsD0hU4*VA$WBKr)|;4`EAr-zmBP)b z0&9XMn!+Gb7t?rTWal`cHF}N$U=*mT11tBCHvx$5mH%?#VrK33Nj1*bHr8P{W}VGsJFUO7BM=lYF@ER~?IA)h?PLq`313?Rbu3PQvn zSuC4_52Y23XN%~ZRJP0miXJ>>?BepzZP)PP8X9FcHPEO+QJ8r)p162WF}27? 
zLPJxCV$WSRt1%b=iV4DU842idu|@l}q2_(5HIPiwFM;M6coeG68Q{4sMZ=tha4Cbl z&}PK&ihdFt00S;o+oVa6mzH6wYhFKb{F0jA#o7!MFfmA~dJl^yZEGkOi)e_F;!w%r zI}$?P05zKA)?SMZ^JY`Dh&l&Fz$g^0TBM9Nx*N>n!!4V_`TWF_iwp+jrY3`N-UEZ} z0_nX2+DN~+QFh%3h(Ku4n}1he(waDzTh)od$7lv;gDn)0>m?Y<5lvNhO`(fyW@^wZ z6b}4rFoUWGtf7WYYz#gGb$=Da)Ab2fAD2R9opFQ0wyVnEo7ZX&4WVK1kE6Bi6+7!Vg3l4 z$^qdfia#>^nCfEyLoo>*U*aR=pUw}~2rd=x zg^$M)^`_T#QeZZQ66p|`4HbJhgn#VFP{Oklp_bhe*A3;G51O7`U~FtuJR%}9wd@>V zF3l_}jevB_F#jLV%t;~lEZn1G_Vb>NT_j6N)r&z;9{B2$IP_b%KVtyTX|nYzaXp}E zpkw*`TFK3x1@8d4tT1YDNRcN@J^TBDhYP`uNjp*O;;Ibvz$B8`U&fi7l@ z?0gfCOv9*}rgLaIL|X=ms%>h3iTf`|g1i6eT1;K%G2|y*VnIzULK=a|0yVP24lJJW zYdmzZ`woXM-qkT}*U0<;65n_+T;ZKTkWoenz!RZkr!~odgCoeY5jaXzk@xWe33p%( zR>34Y1{UHljSX)3N%9ZVt7L;5gKCgy1Cc9&>_KNGl2Ry^6ib;JzYR$eYdn-F2L1&A zN{T}(N=#`h?wr=(?G$zc(D>(3MpjxR4F-RCd>|@pky4{H8i}Y;L<_$@5<8UUr2#5z zERN41mR}T19nbNf^AG8DFU>iKK%*#>&L|+@s4LP`x}ZgR23mMbvp7I0VIhkHdw-Pb z*f$FaXKE0W-ISnp$I|dD7ZIv8hqwx;=jteTPnLD1?YZCK2^8}pVv(#e7MkZ5!IX)i z;b>C5nX1;lM3jILaT3X>p@xfuF!l8dfeGgVcQ=;N%~GJTfu*%3{lWs`pM5mzZA$`0 z?3Wu@Ow3~}2P3zvlX7` z8(`4gAQZ7^5Un<35Td|(7Jju?Ois+bC`z*M?F9v>qzaJXou!+jUS$9aMQ3TlIF_BAK~kx<;%g)O(%H(=vF-oVK;t1!$xPXtTr;`2vsGC=cNnEtJjwKS z7K0|rub1VhQ1^nbP%GB4=q`(Ly7DT`vB#{ZJlh15?*=_yw-~os>7r`Nwzvc_O7^HbHKs!59Wg&uvE`nrF zDmJ4@_f-Fd&CFf$3Zs&nx-ZkDvn!#XA>shiKn*0Y)64_JCekeYJ{3HIcTpCD<9M}n zJXl28ysWZ$PX&q^8oQc;Vx|(}M@bDyJh~imRBtD?$>caUwwQiUCG7bqys& zhZvrlp}6SvX_T-~n7620nFJ_IPg5ZZ(u>|~upcJ* zWj|i(qDmDgv%d;6ssXQQbsgfa67}XV5g8s)rVIuGbVdtKNwDp%$s%BHHIvlzN{N*v zmst?SD4~d%o}IkNFHy6Um4lorhu3RwTy$nVT`qxx!DOhqqZM%RKN154oOJ$i9LXW? 
z;a}G|)g}mu7-4LPd*#TVH!?(gh?sMWDUeap}ED7h4;a)_a%MwS^5}vaZge1dIL|UjKWd(rnzzqdA zu;V02LYBB1$mN?<+_>A77WIWl_%;;L)m-BLEM>RrkMhTWD{-?NZ0>OR4vL0qdn_Y4 zund7%D*SQR0TJro@(5Ti`!p)T&~q#pz+(bzVkz1+I`u(x4%qB#F;PFTt9d52Gsp7<-MG^>YH(2J{*cfizdX>DilC{HG8<#|E!CizE zhjq3~w8b&8QQ&1vy&fwTO0~2Y23OXa*ywDX#zV{~Z5;PieR1Ti&*?7OPkuH4N45z}1( zrU@Y2boXgPqitou{SbbULJ-kBRw5#6n6VCDAf?WI}o-5ME{M*$l zm6)X_E|n4LJ+`{ELQ$_lBVBAVLh>o8+3Wfg%=LCs?C^<3%15O6n7Mfqs%^vl960z*FRF@6j4vuh-M!L70T*D+ADvf|aC8UREuq0Noyk84Ga_|~KCxBuEdJ8*S zghDo9qroSZ8<ljkQ6DDjyM0)VQac`zTQ-fbx0VR?yzWrDAP zD(m<3prX=rn9UXo5?5F@>6%bK3pwtb%ay`qh|&=XLh1M~jQ8`w`_Hs( z{qICp8Q$<^4mZZ@XN+r$-e5M(C!IaM@inKk8Cc#Sz$$%QVsqV1AE^ZfFmsDy#7LXb zGQkGZLgXY%C5DouWI?>4M_Rn7?MWq$gpzV=2Z{7UFvN|^GFQUsp-rALM(ZBQf&^=6 z0Qi!MsxUK>$wruJGi%~<--Oyl`qD(KH|atFI<+_Ly>-~S$SSt3tu-Ex?YQuCI`zIB zPs_T<;n1bGzAnP31q7Pyo0(oVeYx1>Wq-^EL(H zWEhVid_uGmgMfB zJ-|^G7;3ao(OvUQ55gOO4lSBgr8H^H^Vw;vOrn{i;Dz{u3d@olkho|xaR%Tu`kvoD zDV~gx4br)SDTx5g5(o{UUe^Fo+@dV;9L__VzrhF;xD^nmSO^(;5z)nA8cVt4W0=7z zp0aKs5psDVdUb_2Ld>`vNQX8C2aH}0Ok)OVDvPi{*zhh@EiWLA$RZ>GJFR)%4~>-8 zrIU@eHXKBb4d=7mAN$R@-JI6Pr{mqB-#?vBH@Bzz)9H9P+??8>wPg|MY|-`Fn^J6- z%?!5hdU@7;yF71~kC)4bj~_mK`sK%$zy9^(k3T;D@Z-x5KVN?MY5RDwZPUwU=B;u> z(CN(@I%L~72WL%Whc()!qQIu^$Trt@Om|MAB_iuB0@lI6XlOwO1%-))NV9g{SJm08 zQuovZ(oUJawT;|$55bV32ODvqx&_W{jNv^RhR9x_N930aV^^y3|v|B!?E#LpszZ zGiK;l%b@s|*t>o*jyaxgI^_8qE)-P4xs0QB%#;Fm!c>ld$s9-P{Pg(Z67P(Rc_H381Oimd$LAkASG#T zHDo@wK6C$>zS}`45oP08UMtF&Pn5AlBS;=+93*6=4sxdM&321Jq-8Y}%#d9Kl@yBR zhDvO!S0co6MU!$Lf`Fv86h)vk5vsN(UjDlXg|_tG#RR%8(i=Lk66?~Iz8nrb9oyZl z+}|EF1yS`r$8s`S9akKm7H_Pe1+i^5I!OT%^getM&z|s%E-2HR|pO4AF#i zqxOg_M?E1uGa-E4)*Cd=4FHKfv|y~4@3z5$CR2zKu&7W~4wKY9H0q2uf&7%2;b#dV z4^Me?3$YX#-rL84uL8(wkO}=KD->r0DV78oOsWP(=Qv?>lur>16o{VjQHm*_e9&Ab zLaJ+qXi4E>z)EDt^zuNN8pV)2$rT~PG!iXQ^Ob_}A^9dR1gEFWoCu<7UlBn82w`I7 zGh&uKNbo+Hl{vnG=}`#?=6Id3Ch3K;C>-|%Y(fw%JoO}07%9O*&re{S24E^)Kq(@` z+GJfsU~5BfICfb@4#(xNw8M(Sx~{7ndh1K;U6$V6lNyJHH7U=LQsKw!d6-+iX+a6`QK6?wjr#6so)J8+N5qcExT`)yt;aF1zZkwu`7jciW;hW;VJf 
zOuem0Ml}U=Uhv&fN3W5H$DnnPX)ywTXq{geKEJ@RssB?gmkM#d6Hi4@bMM`(z6W;# z$U`Ku>w>C9!rBzgN73!{7J+oIrSQL2CrJPrwW_9BsDRW%3EN9~hWkc5IUeqbW@gk4 z-33QkK4eCy(N{(($f!>)kGESu!ma@(9N8$+*F!;C1b}6m!$cSi^+*QfjdC?tX>RB? zk6!_#NwX$s#s+Hwn065nti(z?A9;IP?oa*cwmrRF9`A1M9!?LB_cu>(jyLD^_P!s_ z?QkP~Veb%97Kj)LI_M=}t7mDM+J1rUy6@OkFTd#JL>3{qW~M|NNI9UjF()KfF-UcZD@7-30BntC|WxgCC-|n+gcb9N@TW@f4fcr*%^D{sxC+~USA)23nN4n&4+;c7|f=B zuQ@3!fy9!T8dQ2Do}5k@Mm}mxR6=eH068OD))!yDEHv8qv1ZMP;za}g}mqM|j5h0&^<;ryb7YeO;MIdE*uJLN2 zhV_*rN!qF!fHCR`Nz^_iet6|~`SQ!If63*!3XM23)bBQyMiN0JMsRz_%9gXrE(q&q z0^B1e;&4I2jCTcaHToo(VL7l4ez8O(OyyBaHmR0W zbo??!&)m;+ocWhBBOP>yWq>_XMt^ZK&v@|@UnHEU(c)96OJ{EolcjNOSX=Llh*%fv zT^6y^($1^i-1OV?^6=*L_;mmD{^92F{^sdvy}e)Vp89&~hcnuZ$q`;SEFIQ{0EtFN zyX_mzU>kYCez9%iNoy!k;ts2ffAnBeb7|aph3R@-UQgcei_jgN(Pj z`br*QvBEVc7DvUAiu)U>gzSx@x)iEBsqj7 zOc7Yi654y|Ix7J{+RzeeVLDR+ks5~~SD16-Qzjb07!#}TMo)SW^vbS0$0 z{nx$8@woPeg}RC?g2Q5`liaM!X_fP7Js+0iu^ra+cseevCRvYV>zyKY-p1=|J%E*sV+ zmkpO4YS>gag-R4!isGVqs4&93F|({2_5hC2yu+jOnKdsL3NrF4d`s08wAIA@^ln#& zOdR^D$YWNDWSHzJ-T@$Ux{!e-W94rQN^FLdKerz{SvK+W9kOl0z^}{Qi$W{^^e&{`%ANPapY<;{4YOrM2*kHIq zoOWgam}}=wEOcm!>XJ zkv5QUw@!2)FbfsxPnCjMdC%~_(NAhLF))uQ^^ApFTr^4}8X?4?vC)<$$BwmeZMLq~ zJJ;S0s~kISZ`G7-EH}Buvef{?K%{Rx#Z~N(?AJ5Vc+K-eK(tt)bS1q6b4cL7r z#X7J@-yL%_xx>64_h$CI+Q#j2*gl;vKR;bwzWe2;=b!%a;fFu}@gINxumA9efBKgn z{`7A@{`|q7U)mzg!@Zrl?W!h?+)Myb=uR_h; zJ9f0ZYh`j7^3-temyr9P3Lc&tq5#v2hjQ;(sGAZEPE8Z?wkQs^S%4<+F0qpiZYOcg zQ$NR*uU<190+t;=2Co3_@QpXb3)4rn3nJi{F;5P2*33$Pms`S9K1!yg9!b5F5rEG} zT5qNV7%5!|K7HdFBR`1&_(hpvjhR7HpF%itq6gwf^g)NANi46Blt!p`?G!*7x}b{) zO(^c~?_GKqAv)!_im0AMPDi=BUGE?I&3U~!9qw+{(`h*#kB7A{hu)X9wS)A|<{h;p zdQ7J#auKAuSK?`8`FLW_ZZF%mZ~M#3%k%U8>G|c;r|t55`Sj`HR>8LGw(HA_Zkuk~ zzNu}ST?D(b!@6)cZi-E@@79#Nim18MUX5F-;(a$ENKqIY8f4ovQE_`4jC{t?;7k_hA20~qroB!Li8!3^1q ze1ciDj*adNPLVoIuB=!O9KkrMe4@5~BI)Tj37>PVgLh@-q*eu{8o@&ttCNyH*UVH% zM*pB{1r%z-N17>Vo-ay^^Bie$Dg|f|TK8K3B52XT3Z2r3waKCLxU^%J+oQaB>R&yb zzWe&_{nz&oUp?KueYZZm#o@jk&X6Pe(WIe~Y<@qA!Nf54hI1yR#o4ERf~>Ekk~@~& 
zc9r91`rPhb?B)6J_QT=R&$k~xzWwQ!?|=O1r$7JYhd=!Ir$7Dt`yW33@R#SGpZV$8 zcJk6>7qdp)Y*UCpHje}A<_>aq#KZ=Ma*vf!Ew4oGN;(a_E+z<`^lF>BEd4|Y2GVd3 z4mQUaloWk5HjE?@A453rn-@30DDXotUGw-o!@-yIT0ep>U}*^mxWg=33+Ppz1puaB zD%{M@FnK{>oDlDsXP1+gfx!ls3@c@VloL5#zYGbXV!?H~SE$`#Wam^T^k01U#oMFw z(d^?7Xg>^%@x@5Ean5ULCV2vR>a}vsPb;4(7f7L95fMX1Qe!wMWH7Hx49l*E;TOkk z-N+l4TXq_-`CF6kBSA~mMUOSVdh)A33e|&e-~6U-_t|w<3x1U+_9d`!{k_j0hxBA$ z4=jT&19I53E^)QeqTcTM@m2=dkT=Zz$n&oP9i^{$bdf5kZ^*M_#{d(v5K&81HbU;= zM3G7*kgc+9PznGdV(Yn}oIV={4?kZYcU?Jpp%gNVy>%Cz3LWK+o3NP-Bq_egslQp+-EF_37HY_q;5bV;}Gj% zqGpjw9<`P8Ek+$q`{1d<`V&)JM+l3|2@8!DnXO*Ln1`uV9fD;84MxwNbSKgCa0zei zQ|8eMtdO|IY-^5UrxA~b76>l8TqI4iL?9@gH74jdO*7oOLCduR{$$2X71yW92V z{(N(PyY{}G4*jr5@4c^VP1@pKsjyX~!C|2RnXTc4iR!7$wwsylrnc#}>AtJ#5E zw##MPcD+1rpFX{O`uOtv>E+X>=jZ3``DNR-ecN-0BDAN*WLTB6T5fo^{fMRHt{}QPSpydTgW7~u}S_9vF|adkY!fQI08l4s@sh5Z3UyG3L^Gb$Y-;}&FOE=3RWbVNBxCd&o0 zXK%E0+-TseG?PQWIBsJeu97L3C~k_Z!f5dZ;7n)*gz z^wLFfeyAsH3^n2>T;!@i2udjhuk`glDP}_Li*&K$($}RQJ8nB3@0Ry(m*0GQ`}MbP zzWz_|PY-XGhqrRN!*W798-(ouY2iaOC09iwc2ac2lXhliIzX5hGy_YHm3AQex?!X4 zwmoxy*)N}Ld)}Xa-amfWKm79i^Upv3@Y4@}`r(iN{O3RX;pad8>G`i8_m3Mk^oxhQ zJfCQTnr^Bu+pbyVqndqRuAIk85b9_DIYFAYKaEZRMWe%#31MQ{&Bv#pW^p`20*sPm znozbpyxEY)%v%O|7kC=X_TW4*tFsTXGYOpA7%XXKKm@=^Kvew{KJFgH2n7+=g(b+a zR3WK;yub(`j2d~4&eUzAfx4J^G9c=!vz#7w&Ft|0(#wG8hzJ`#TT9OX;D>`H%@%nV z=S*n4bj)lcOt!plcIePi7JOwqm|g#g5Ia6jS1PZbSgLut$+e!#0nI09O%q7j@!2w! 
zq?_hD!|@Rx0EU%vb)nBFMc0_`uVVESJa6CrcBtzwuG2u_lOg*bzFwi|ym4CK3DTzp z7)DnPaI7MY(MyD37vzehM#+VOLmgiv?{w_NY?#EprLUKXTRAA(a`ZJ4|399LieY(u zJR&r%G!t!MByctuf|4g2G-ooy4FqfPlGoLAW1f9fVQXx_5v>> z+;fSw$lZVCdPcM2w9wyJqZ+T{P>68|LfoV!~gjI{?kAHk01a1!}jTg8``G(WjBRwmy4?SF)`l#5GGrcMVfje zps*02sv3tP!Nb#g#E60H8S3p3!zCU(DptD)?4}6+xrL_$vfFNvXe99wqBZ6aUvQ`r zjXnt&;b*_&s)vaw!`qz{RD$qY7-kF?iQ-e5B@(KZL@1%=ga!8$P2$huoHwwNY+gx> z5+fFs0F`0sep0tUt`hk^f4!F8ujSn&k|T$zviW@}=ye~g2x~0Vl+j_3=Op{5ro&fdaC-noQyL+* zyOtesuSpX|hH;8*|9*7GBxlm_>!UAGHqi*DzZH|P5{mds+NlaHukk5#5yry4M zFJ%U4GAY>b7AH^$4V{8LS@#vh2tp+RuO6L7fznN3nS@mnSrcwdxvrpjSy5_s!$`Zw`0& zINb4ggVxcGhEC}aFCbhpKG?$*NwGk9;SmoZ@0$if)RQtTWNze5nOqq_cqMKpcy}dL(Z5K7X zY`X2bshJn$`56bQ>DY1qL6C@K5?IC&VTfZYv=FLAlO>c&51e2gi=K1Y^VVW1c|&j| zY6mT={0AdA%146!JnS!}00D?szPX*CDG*^)vwinuCs%WBb;Q>+GH1*HhRV97l_V<;9z8ZaAt zG2Pe0E1@vozO*W9b`vxmXg@JSUYk0yNP4_;wMNeC?(B=1X=2(hymHFZJ527I5x6ao{>z}h0CvuQk@`>>W6axXop&u_u$+cO$dQ$?{h}6DaPeAWhuUQd%gu7QTW)TSclY;i@4x+#69C0dGVU-zht2>Ig!H1~Tr#HOz$UU@DQt1M)Ud^ zg9w-AO?0&ri%9Eej4YtTDr@JuaBaLjE|2Hs&BNj0@%;GqbaOu5KHjaz^Ljo>Kg#0H zQ5r?08_-*$bdctK&fUrjBm(cbA5#HgDll^rb|ZD)Y~Qf$x?NN+wk_M`WxHJLV0!MC zm*wSo{rG(P`0?e#hv$zUo!~}>qINY$iC{9em$bk$F1`)@g zj3c*fBA3uhQa2d#CYv^Z@HCkK(g2J=bHCXf(;cx6qba$xOq*p&Oe;iza3Q&Rbp8); z;7n=oBJ#M2ORG>`2fPIWZXiVY+CSw*gRf(d7M zuLL7ih-L!Q-b4hgL7>Yb$7SiJ=ibgIdAODT^qYtO^xfU}zk754{_E4zH|_Mm^^D~P z(hQAU$rdMFw+{9zOu!7k)yzp+&wX%#AZgk;Qc0;Q2eTdKM6zMOm~LnnS?Tnh9I* zSS0xbp2&uHBt+9Q8{`UQ_Bmw8bd1ffmFZ*rh4lm+bJXQ0zoN`X>rLlh|TR1P31 zFP1hlQeu0P$4S=*-Y-0SNbQLARt{esPw)G~ACGU}zWvR+um0bE|9|}L@Bi@6A3wd| zMX|MRsw7Gy&58}k(w?y>}t&Qn*6fKVe>4f3%@8WoYbxSpqX(IP6skEU7`8NTb)9jK&t3T z7JA4xf&dBhT!>0m0+E>u5t+*>WEb&>S#K1I(3Zw!kwyC2*=cK+)7o#2a(}bFd2@Wa zJ3T$!+&!Mp_vhoyt(?x%4$HEzEueU&hKR%=qZCT>EQ9P$Ee<0<@Af@)-3wP``2el8 z9cDDu?PB}I_O0)ieY@zk?JrwzYhPa0lOAvO6Z`BpFUkKo3^$s zU0#Hn>Scp83bi-B$R-;oBAaeItW!j`U13x;fr|9a)cjN?!uMvlZ?Vvj%bH=rVJf5{OIC zdJqsj$sAxe3FmxoUfd#!v~}faZRg(3E5CYo`0e*kzx&T$ef8~|+c!_^-8(tmaXGOa 
zAPb}kJq47cfPLDe+^et`!z-2u2taxet^`C*!rO666vqNC&q*$n6+GH< z8so5H1uT-L!VOcVSL{X6hB-kX`tcPYag)vrS)Wcz>mrWSiiUC=z|M3ea`*~*bD+4R z`DuJ~L)+qHK!(H2t6Uq*{4~vSoYugLZ(LNHSI<|AP&4Ei1_lE&XV8w{j|hwDY#|;SXwD?e8&1yNzX8S((&;`XO_~=2|-aeo}MbqLGPM892s} zX;$%m5^gnhH)9ue(_PWf7RcWEsU2>X)A4*d-QFz^Z~Nc=zy9^_{_e*gf2ML_V`IbS zOiU$SZ13U$SpgtMjZF-#>WO3;l{86dMX0H;om%O&U>xof0Ss~I41$5xyDP!1^lL%M zN;t+4DpPz2SN3}>ibN54ep7}#;{!IgieKt4kFc<`*$vtit5 zX4zNojB7DFVqU$j1nwkZEp(OCJ>qhv$|B$iBv!&RW=#e;rTi}W=pi31?fskk_4al> zpW5N5?I6pMghALk*`zfQX6YTl=BIWCC=xHAcSS6vKnw;pO2p|+qNJK;X3D;C-(Wko z7i^pCoAj5yZSvCkzV&50EbVey&!_d5(|Y{*<1asb{P^ibKB0?t)0a&J5H&H|+=&V0 z=G`onO)il>X8?AHnzlHho4YAs4c;o(GC(=%cRfoX<(<`LPMaS-kAREz%7{#@9*d#i zIL~AvK`YpF4GtzUorrzZ@q`gEJUgCDkyfGQQNh!#A~J7RPl}BDV)bHe4s$yl#g434 z{;k34h$GU799?pydk-8&JN3BlUqJ>9!{cwsKgIV}iAn|A03{b|2gm`-Q4Y_C!(lm{4>zat?fLfp_;~-vfBL)q-~ROR zr=PW>Z;c!H0+Xdvw~Oj-YLMva<&1xku)#x1*laHb-g;j^|fo=2Hk>%j6 z5}iUc)xbD2%Yb{9ypcZBSGfOkqjw^K()P1PNpyM*DN_Mmi3rpxCT<*0okuw-^_Uhz z`3~0?lCM)VOqe#y%qE~4nMEACuK_=c0LCn>P-=>XAyp(F7Z0i~uK4_V4oZAA-?3yi zEm203W{rO(J!_GCk7fhl^-iV;E3lA~_f^U9s~5+L%wU28YRg%&{-)~tz^ns3pXTEviqO;KuF||N5;BDOA=3RW*2ngxbVF4xbX3=J)U`g zlBb(?b85$f99LaCm!`|AZSDQg5BIm{H}6kxzgr%^>bLK4d_X^2JA)lu!4;&}^ku-T z<@Q-=ncB(N>7tNV{Q3Hi;9*0{sXDiKVVwbDN_H{neZ7sY8*Snkk?qU7v?d{{6 z^UcHg?(O4xzUjxCz8+MhwN+Xt05-N}Y-F4q5aB6$M^s6{pMZ&;YB>Bv!9PJkp&itA z0FcJMz_c%`?YnLlTQ0h7JRJ7ROM7`<`qulU_1>1{a6TMPhoAoXp)D_;KE3n?z6dXd z7Zu<#YK~El>x(Ju?d<26t>1$hWti5_hs`i1BbH8%GcqbY_br@&N0(jsfx`K4rL`X1SGQ z0?-p<4f>Cc=)N^bSQ(P3D4UFi-_BDu?lF8rr8uls@il2&L4|jxTZd{H)xU zyq%TkBNCp()v6j%KVE&&691=Lh2aU~^sEr_!~1EbNE^UhttnuUfuAvk;N$9sV%!BS zL^A-Z6F^34l0R3iGMUVvjs?xKv9CJC1o{spU_hJ(C;?p}QCne?5I0MFsL@O(*E7!Z ziri1{f)buvA(zH3TpG>`ZdX3sw1?aFc+01o_ITUQ2RW|2FVfq#tmsWB+78R%?*8=t zo5R!h{r)T7ys^U*`WZYzjsU>s`Teujt!P>H+m_nI;+LXW*r2rSs))|Mtfpcf7C@wnKIp)C3gTW`KCXlgR|8UA3sHi0s!K zB{@!~q%6Gxp(5U6pdHid*kXL zmsCa9LTb`e5ry~vytAHP9aHf2vgiuWafhg~HZU9exa4wcnTFm#potp*jA#JS8+wxt z>y4du=v)^$E^_F&JuP>K{`TGJt2f6tPv`T)?aiAv%gwE@LI#D{$R9M 
zF^cW`?g#W*KvGk4s-x!NtQ<PNAF_8^H_e=&7OBDaYYZTe*yo~BE zHD$T~O)1f!F7bUx16LLT^V~Lxl6Moq(kR5bw4>l=wa2^m^_%r~zkB%hx9{%XzBxX< z>$mq@Z_rPa1KL6^Ak%zp#)c%!%n0S@K+Q?nvPWT2^_5&eWlzc2+0r2aIy}O$8GzP^ z4(-NfvPxT*wzg&Ihjl%)(|Nr;|NbBU@y9>^`O}AOYfV0?VG+Y(`@SncKMUR{#2!{C zL&F+pdSdR1MaN86%HHtJ41nfd`z~wPJ{ALa-_R=pAgMuV31b3O6zM1Q$?{7=1h#tz zeFVxDhVPIX?qmyPl;&dMHjRcw?mrq~H_HXKE5c6H5(|^Uwm?oYssY~9FOrlpv_Cs> zocV`UkgMx)z;?56l;$sU2StUPAX(x3NcnT&~eJfJW35$ zH+GuK_a@!yk%8QS@TfRb8Q~|5X}M<-VF1%0d|`s$rs=ByXq>F#no4>@{Ob}pCi9Zn zEK_5Hdpyrp%pKTX+UuwARJ`49=kIr6?>uxaz|DfETY0*Xw-4*ngFM{e?%WTHEen@b zq=~dnYG0OqJoNj=GW-I5=sg47+duTL}@T4iQpyHvl~H_`bEn&GGb`!_gqWZ2#x~_}f3~v*&VkQVF`d zxKmZ4gmU*b9&|UWGpa-xy7M9pTD?uu*-S)LcR1Bh0n?&h%jsk`Q<(z6dlR%B(cbc! z*P1ccBJpkF@sU&PpGQ%!hJ!#|-D3@;{sAD^Q){TF~cq>+Fhw%Y({ z6rgwS!HM3`nC0@sp>YwB#zSjM$0Bm>{r1=%ZkDIVwY2|-XJqo>g;s}#^NOGvq?_$vV_#%hY`?H~ zwpIIz%cp(2te4)}(%Prids{xX-rA4&@bQzhO@J4Ws_dF#*22Z#KxpLhckCv4m=bMQ zHHc8vwTkD*$!V8a8%