
Commit f5afa91

changed repeatable layer
1 parent 4bcffd1 commit f5afa91

2 files changed: 83 additions & 80 deletions


Lines changed: 82 additions & 79 deletions
@@ -1,118 +1,121 @@
-# Copyright 2025 Lightricks Ltd.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# This implementation is based on the Torch version available at:
-# https://github.com/Lightricks/LTX-Video/tree/main
 from dataclasses import field
 from typing import Any, Callable, Dict, List, Tuple, Optional
 
 import jax
 from flax import linen as nn
+import jax.numpy as jnp
 from flax.linen import partitioning
 
 
 class RepeatableCarryBlock(nn.Module):
-  """
-  Integrates an input module in a jax carry format
+  """
+  Integrates an input module in a jax carry format
 
-  ergo, the module assumes the role of a building block
-  and returns both input and output across all blocks
-  """
+  ergo, the module assumes the role of a building block
+  and returns both input and output across all blocks
+  """
 
-  module: Callable[[Any], nn.Module]
-  module_init_args: List[Any]
-  module_init_kwargs: Dict[str, Any]
+  module: Callable[[Any], nn.Module]
+  module_init_args: List[Any]
+  module_init_kwargs: Dict[str, Any]
 
-  @nn.compact
-  def __call__(self, *args) -> Tuple[jax.Array, None]:
-    """
-    jax carry-op format of block
-    assumes the input contains an input tensor to the block along with kwargs that might be sent to the block
-    kwargs are assumed to have a static role, while the input changes between cycles
+  @nn.compact
+  def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tuple[jax.Array, jax.Array], None]:
+    data_input, index_input = carry
 
-    Returns:
-        Tuple[jax.Array, None]: Output tensor from the block
-    """
-    mod = self.module(*self.module_init_args, **self.module_init_kwargs)
-    output = mod(*args)
-    return output, None
+    mod = self.module(*self.module_init_args, **self.module_init_kwargs)
 
+    # block_args are the static arguments passed to each individual block
+    output_data = mod(index_input, data_input, *block_args)  # pass block_args through to the module
+
+    next_index = index_input + 1
+    new_carry = (output_data, next_index)
+
+
+    return new_carry, None
 
 class RepeatableLayer(nn.Module):
-  """
-  RepeatableLayer will assume a similar role to torch.nn.ModuleList
-  with the condition that each block has the same graph, and only the parameters differ
+  """
+  RepeatableLayer will assume a similar role to torch.nn.ModuleList
+  with the condition that each block has the same graph, and only the parameters differ
 
-  The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation
-  """
+  The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation
+  """
 
-  module: Callable[[Any], nn.Module]
-  """
+  module: Callable[[Any], nn.Module]
+  """
   A Callable function for single block construction
   """
 
-  num_layers: int
-  """
+  num_layers: int
+  """
   The number of blocks to build
   """
 
-  module_init_args: List[Any] = field(default_factory=list)
-  """
+  module_init_args: List[Any] = field(default_factory=list)
+  """
   args passed to the RepeatableLayer.module callable, to support block construction
   """
 
-  module_init_kwargs: Dict[str, Any] = field(default_factory=dict)
-  """
+  module_init_kwargs: Dict[str, Any] = field(default_factory=dict)
+  """
   kwargs passed to the RepeatableLayer.module callable, to support block construction
   """
 
-  pspec_name: Optional[str] = None
-  """
+  pspec_name: Optional[str] = None
+  """
   Partition spec metadata
   """
 
-  param_scan_axis: int = 0
-  """
+  param_scan_axis: int = 0
+  """
   The axis that the "layers" will be aggregated on
   e.g.: if a kernel is shaped (8, 16),
   N layers will be (N, 8, 16) if param_scan_axis=0
   and (8, N, 16) if param_scan_axis=1
   """
 
-  @nn.compact
-  def __call__(self, *args):
-
-    scan_kwargs = {}
-    if self.pspec_name is not None:
-      scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name}
-
-    initializing = self.is_mutable_collection("params")
-    params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis)
-    scan_fn = nn.scan(
-        RepeatableCarryBlock,
-        variable_axes={
-            "params": params_spec,
-            "cache": 0,
-            "intermediates": 0,
-            "aqt": 0,
-            "_overwrite_with_gradient": 0,
-        },  # Separate params per timestep
-        split_rngs={"params": True},
-        in_axes=(nn.broadcast,) * (len(args) - 1),
-        length=self.num_layers,
-        **scan_kwargs,
-    )
-    wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs)
-    x, _ = wrapped_function(*args)
-    return x
+  @nn.compact
+  def __call__(self, *args):  # args is the full input to RepeatableLayer
+    if not args:
+      raise ValueError("RepeatableLayer expects at least one argument for the initial data input.")
+
+    initial_data_input = args[0]  # the first element is the main data input
+    static_block_args = args[1:]  # any subsequent elements are static args for each block
+
+    initial_index = jnp.array(0, dtype=jnp.int32)
+
+    scan_kwargs = {}
+    if self.pspec_name is not None:
+      scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name}
+
+    initializing = self.is_mutable_collection("params")
+    params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis)
+
+    # in_axes for the scanned function (RepeatableCarryBlock.__call__):
+    # the carry tuple is threaded through the scan, and each of the
+    # static_block_args is broadcast unchanged to every step
+    in_axes_for_scan = (nn.broadcast,) * (len(args) - 1)
+
+    scan_fn = nn.scan(
+        RepeatableCarryBlock,
+        variable_axes={
+            "params": params_spec,
+            "cache": 0,
+            "intermediates": 0,
+            "aqt": 0,
+            "_overwrite_with_gradient": 0,
+        },
+        split_rngs={"params": True},
+        in_axes=in_axes_for_scan,
+        length=self.num_layers,
+        **scan_kwargs,
+    )
+
+    wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs)
+
+    # call wrapped_function with the initial carry tuple and the static_block_args
+    (final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args)
+
+    # typically only the final data output from the sequence of layers is needed
+    return final_data
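
For orientation, a minimal usage sketch of the new calling convention follows; it is not part of the commit. It assumes RepeatableLayer is imported from the module changed above, and ToyBlock, its features field, and the scale argument are hypothetical names chosen only for illustration. Under the new carry format, a wrapped block receives (layer_index, data, *static_block_args) and must return data of the same shape, since nn.scan requires a shape-stable carry.

import jax
import jax.numpy as jnp
from flax import linen as nn


class ToyBlock(nn.Module):
  """Hypothetical stand-in for a real transformer block."""

  features: int

  @nn.compact
  def __call__(self, index, x, scale):
    # RepeatableCarryBlock calls mod(index_input, data_input, *block_args),
    # so the layer index arrives first, then the carried data.
    del index  # a real block could use the layer index for per-layer logic
    return nn.Dense(self.features)(x) * scale


layers = RepeatableLayer(
    module=ToyBlock,
    num_layers=4,
    module_init_kwargs={"features": 16},
)

x = jnp.ones((2, 16))
# The first call argument becomes the carried data; trailing arguments
# (here scale) are broadcast unchanged to every layer via nn.broadcast.
params = layers.init(jax.random.PRNGKey(0), x, 1.0)
y = layers.apply(params, x, 1.0)  # same shape as x: (2, 16)
# Each Dense kernel is (16, 16); with param_scan_axis=0 the scanned
# parameters of all 4 layers are stacked into a single (4, 16, 16) array.

Threading the index through the carry lets each layer see its own position without unrolling the loop, preserving the single-compilation property described in the RepeatableLayer docstring.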

src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ def load_transformer(cls, config):
     weights_init_fn = functools.partial(
         transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True
     )
-
+    import pdb; pdb.set_trace()
     absolute_ckpt_path = os.path.abspath(relative_ckpt_path)
 
     checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path)
