Commit 13656fb

ltx-video-transformer-setup
1 parent 3776190 commit 13656fb

14 files changed

Lines changed: 2036 additions & 19 deletions

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 15 additions & 0 deletions
```diff
@@ -1,3 +1,18 @@
+# Copyright 2025 Google LLC
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# https://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
 #hardware
 hardware: 'tpu'
 skip_jax_distributed_system: False
```

src/maxdiffusion/generate_ltx_video.py

Lines changed: 18 additions & 11 deletions
```diff
@@ -1,3 +1,20 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+
 from absl import app
 from typing import Sequence
 import jax
@@ -50,17 +67,7 @@ def run(config):
       text_tokens,
       num_tokens,
       features,
-      eval_only=False
-  )
-
-  transformer_state, transformer_state_shardings = setup_initial_state(
-      model=transformer,
-      tx=None,
-      config=config,
-      mesh=mesh,
-      weights_init_fn=weights_init_fn,
-      model_params=None,
-      training=False,
+      eval_only=True
   )
```
src/maxdiffusion/models/__init__.py

Lines changed: 9 additions & 8 deletions
```diff
@@ -25,14 +25,15 @@
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 
-    from .controlnet_flax import FlaxControlNetModel
-    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
-    from .vae_flax import FlaxAutoencoderKL
-    from .lora import *
-    from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel
-    from .ltx_video.transformers.transformer3d import Transformer3DModel
+  from .controlnet_flax import FlaxControlNetModel
+  from .unet_2d_condition_flax import FlaxUNet2DConditionModel
+  from .vae_flax import FlaxAutoencoderKL
+  from .lora import *
+  from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel
+  from .ltx_video.transformers.transformer3d import Transformer3DModel
 
 else:
-    import sys
+  import sys
 
-    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
+  sys.modules[__name__] = _LazyModule(
+      __name__, globals()["__file__"], _import_structure, module_spec=__spec__)
```

src/maxdiffusion/models/ltx_video/__init__.py

Whitespace-only changes.
Lines changed: 70 additions & 0 deletions
```python
from enum import Enum, auto
from typing import Optional

import jax
from flax import linen as nn

SKIP_GRADIENT_CHECKPOINT_KEY = "skip"


class GradientCheckpointType(Enum):
  """
  Defines the type of gradient checkpointing to apply.

  NONE - no gradient checkpointing
  FULL - full gradient checkpointing, wherever possible (minimum memory usage)
  MATMUL_WITHOUT_BATCH - gradient checkpointing for every linear/matmul operation,
    except those involving the batch dimension; all attention and projection layers
    are checkpointed, but not the backward pass with respect to the parameters
  """

  NONE = auto()
  FULL = auto()
  MATMUL_WITHOUT_BATCH = auto()

  @classmethod
  def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
    """
    Constructs the gradient checkpoint type from a string.

    Args:
      s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None.

    Returns:
      GradientCheckpointType: The policy that corresponds to the string
    """
    if s is None:
      s = "none"
    return GradientCheckpointType[s.upper()]

  def to_jax_policy(self):
    """
    Converts the gradient checkpoint type to a jax policy.
    """
    match self:
      case GradientCheckpointType.NONE:
        return SKIP_GRADIENT_CHECKPOINT_KEY
      case GradientCheckpointType.FULL:
        return None
      case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
        return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

  def apply(self, module: nn.Module) -> nn.Module:
    """
    Applies a gradient checkpoint policy to a module;
    if no policy is needed, the module is returned as is.

    Args:
      module (nn.Module): the module to apply the policy to

    Returns:
      nn.Module: the module with the policy applied
    """
    policy = self.to_jax_policy()
    if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
      return module
    return nn.remat(  # pylint: disable=invalid-name
        module,
        prevent_cse=False,
        policy=policy,
    )
```
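The wrapper above only decides which remat policy to hand to `nn.remat`. A minimal usage sketch follows; the `MLPBlock` class, its shapes, and the `"matmul_without_batch"` choice are illustrative assumptions, not part of this commit:

```python
# Hypothetical usage sketch; MLPBlock and all shapes are illustrative, not from this commit.
import jax
import jax.numpy as jnp
from flax import linen as nn


class MLPBlock(nn.Module):  # stand-in for a real transformer block
  features: int = 128

  @nn.compact
  def __call__(self, x):
    h = nn.gelu(nn.Dense(self.features * 4)(x))
    return nn.Dense(self.features)(h)


policy = GradientCheckpointType.from_str("matmul_without_batch")
RematBlock = policy.apply(MLPBlock)  # nn.remat-wrapped class; NONE would return MLPBlock unchanged
block = RematBlock()
x = jnp.ones((2, 16, 128))
params = block.init(jax.random.PRNGKey(0), x)
y = block.apply(params, x)  # recomputation kicks in under jax.grad, trading compute for activation memory
```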
Lines changed: 111 additions & 0 deletions
```python
from typing import Union, Iterable, Tuple, Optional, Callable

import numpy as np
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen.initializers import lecun_normal


Shape = Tuple[int, ...]
Initializer = Callable[[jax.random.PRNGKey, Shape, jax.numpy.dtype], jax.Array]
InitializerAxis = Union[int, Shape]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
  # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
  return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
  if isinstance(x, Iterable):
    return tuple(x)
  else:
    return (x,)


NdInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array]
KernelInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array]


class DenseGeneral(nn.Module):
  """A linear transformation with flexible axes.

  Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86

  Attributes:
    features: tuple with numbers of output features.
    axis: tuple with axes to apply the transformation on.
    weight_dtype: the dtype of the weights (default: float32).
    dtype: the dtype of the computation (default: float32).
    kernel_init: initializer function for the weight matrix.
    kernel_axes: logical axis names used to partition the kernel.
    use_bias: whether to add a bias in the linear transformation.
    matmul_precision: precision passed to jax.lax.dot_general.
  """

  features: Union[Iterable[int], int]
  axis: Union[Iterable[int], int] = -1
  weight_dtype: jnp.dtype = jnp.float32
  dtype: jnp.dtype = jnp.float32
  kernel_init: KernelInitializer = lecun_normal()
  kernel_axes: Tuple[Optional[str], ...] = ()
  use_bias: bool = False
  matmul_precision: str = "default"

  bias_init: Initializer = jax.nn.initializers.constant(0.0)

  @nn.compact
  def __call__(self, inputs: jax.Array) -> jax.Array:
    """Applies a linear transformation to the inputs along multiple dimensions.

    Args:
      inputs: The nd-array to be transformed.

    Returns:
      The transformed input.
    """

    def compute_dot_general(inputs, kernel, axis, contract_ind):
      """Computes a dot_general operation with the configured matmul precision."""
      dot_general = jax.lax.dot_general
      matmul_precision = jax.lax.Precision(self.matmul_precision)
      return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision)

    features = _canonicalize_tuple(self.features)
    axis = _canonicalize_tuple(self.axis)

    inputs = jnp.asarray(inputs, self.dtype)
    axis = _normalize_axes(axis, inputs.ndim)

    kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
    kernel_in_axis = np.arange(len(axis))  # retained from the maxtext version; unused here
    kernel_out_axis = np.arange(len(axis), len(axis) + len(features))  # retained from the maxtext version; unused here
    kernel = self.param(
        "kernel",
        nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
        kernel_shape,
        self.weight_dtype,
    )
    kernel = jnp.asarray(kernel, self.dtype)

    contract_ind = tuple(range(0, len(axis)))
    output = compute_dot_general(inputs, kernel, axis, contract_ind)

    if self.use_bias:
      bias_axes, bias_shape = (
          self.kernel_axes[-len(features):],
          kernel_shape[-len(features):],
      )
      bias = self.param(
          "bias",
          nn.with_logical_partitioning(self.bias_init, bias_axes),
          bias_shape,
          self.weight_dtype,
      )
      bias = jnp.asarray(bias, self.dtype)

      output += bias
    return output
```
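Since `DenseGeneral` can contract several input axes into one or more output feature axes in a single call, here is a sketch of the typical attention-output projection. All shapes and logical axis names below are illustrative assumptions, not values from this commit:

```python
# Hypothetical usage sketch; shapes and logical axis names are illustrative.
import jax
import jax.numpy as jnp

layer = DenseGeneral(
    features=1024,                         # one output feature axis
    axis=(-2, -1),                         # contract heads and head_dim together
    kernel_axes=("heads", "kv", "embed"),  # one logical name per kernel axis
    dtype=jnp.bfloat16,                    # compute dtype
    weight_dtype=jnp.float32,              # storage dtype of the kernel
)
x = jnp.ones((2, 256, 16, 64))                 # (batch, seq, heads, head_dim)
params = layer.init(jax.random.PRNGKey(0), x)  # kernel shape: (16, 64, 1024)
y = layer.apply(params, x)                     # -> (2, 256, 1024)
```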
Lines changed: 105 additions & 0 deletions
```python
from dataclasses import field
from typing import Any, Callable, Dict, List, Tuple, Optional

import jax
from flax import linen as nn
from flax.linen import partitioning


class RepeatableCarryBlock(nn.Module):
  """
  Integrates an inner module into the jax carry format:
  the module takes the role of a building block and returns
  both the output and the carry across all blocks.
  """

  module: Callable[[Any], nn.Module]
  module_init_args: List[Any]
  module_init_kwargs: Dict[str, Any]

  @nn.compact
  def __call__(self, *args) -> Tuple[jax.Array, None]:
    """
    jax carry-op format of the block.
    Assumes the input contains an input tensor to the block along with kwargs
    that might be sent to the block; kwargs are assumed to be static, while the
    input changes between cycles.

    Returns:
      Tuple[jax.Array, None]: output tensor from the block
    """
    mod = self.module(*self.module_init_args, **self.module_init_kwargs)
    output = mod(*args)
    return output, None


class RepeatableLayer(nn.Module):
  """
  RepeatableLayer assumes a role similar to torch.nn.ModuleList,
  with the condition that each block has the same graph and only the parameters differ.

  Compilation of RepeatableLayer happens only once, in contrast to compiling a repeated graph.
  """

  module: Callable[[Any], nn.Module]
  """A callable that constructs a single block."""

  num_layers: int
  """The number of blocks to build."""

  module_init_args: List[Any] = field(default_factory=list)
  """args passed to the RepeatableLayer.module callable to construct a block."""

  module_init_kwargs: Dict[str, Any] = field(default_factory=dict)
  """kwargs passed to the RepeatableLayer.module callable to construct a block."""

  pspec_name: Optional[str] = None
  """Partition spec metadata."""

  param_scan_axis: int = 0
  """
  The axis along which the layers' parameters are aggregated,
  e.g. if a kernel is shaped (8, 16), N layers will be (N, 8, 16)
  if param_scan_axis=0, and (8, N, 16) if param_scan_axis=1.
  """

  @nn.compact
  def __call__(self, *args):
    scan_kwargs = {}
    if self.pspec_name is not None:
      scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name}

    initializing = self.is_mutable_collection("params")
    params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis)
    scan_fn = nn.scan(
        RepeatableCarryBlock,
        variable_axes={
            "params": params_spec,
            "cache": 0,
            "intermediates": 0,
            "aqt": 0,
            "_overwrite_with_gradient": 0,
        },  # Separate params per layer
        split_rngs={"params": True},
        in_axes=(nn.broadcast,) * (len(args) - 1),
        length=self.num_layers,
        **scan_kwargs,
    )
    wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs)
    x, _ = wrapped_function(*args)
    return x
```
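To make the scan pattern concrete, a sketch that stacks four identical residual blocks; `MLPBlock` and every size here are assumptions for illustration, not part of this commit. With `param_scan_axis=0`, each kernel inside the block gains a leading `num_layers` dimension, and the block body is compiled once rather than `num_layers` times:

```python
# Hypothetical usage sketch; MLPBlock and all sizes are illustrative, not from this commit.
import jax
import jax.numpy as jnp
from flax import linen as nn


class MLPBlock(nn.Module):  # stand-in building block; every layer shares this graph
  features: int

  @nn.compact
  def __call__(self, x):
    return x + nn.Dense(self.features)(nn.gelu(nn.Dense(self.features)(x)))


stack = RepeatableLayer(
    module=MLPBlock,
    num_layers=4,
    module_init_kwargs={"features": 128},
    pspec_name="layers",  # name of the stacked parameter axis, used for partitioning metadata
)
x = jnp.ones((2, 16, 128))
params = stack.init(jax.random.PRNGKey(0), x)  # each Dense kernel is stacked to (4, ...)
y = stack.apply(params, x)                     # one compiled block body, scanned 4 times
```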

src/maxdiffusion/models/ltx_video/transformers/__init__.py

Whitespace-only changes.
