
Commit 8bf24a3

headers
1 parent f6115df commit 8bf24a3

11 files changed: 162 additions & 131 deletions

src/maxdiffusion/generate_ltx_video.py

Lines changed: 0 additions & 3 deletions
@@ -112,6 +112,3 @@ def main(argv: Sequence[str]) -> None:
 
 if __name__ == "__main__":
 app.run(main)
-
-
-

src/maxdiffusion/max_utils.py

Lines changed: 2 additions & 59 deletions
@@ -271,7 +271,7 @@ def create_device_mesh(config, devices=None, logging=True):
 max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
 
 return mesh
-
+
 try:
 num_slices = 1 + max([d.slice_index for d in devices])
 except:
@@ -303,66 +303,9 @@ def create_device_mesh(config, devices=None, logging=True):
 if logging:
 max_logging.log(f"Decided on mesh: {mesh}")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 return mesh
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
 """Unboxes the flax.LogicallyPartitioned pieces in a train state.
 
@@ -685,4 +628,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
 initialize_jax_for_gpu()
 max_logging.log("Jax distributed system initialized on GPU!")
 else:
-jax.distributed.initialize()
+jax.distributed.initialize()
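
The max_utils.py hunks above are whitespace cleanup inside create_device_mesh and maybe_initialize_jax_distributed_system rather than behavioral changes. For context, here is a minimal sketch (not part of this commit) of how a JAX device mesh of the kind create_device_mesh logs is typically built; the (1, num_devices) shape and the axis names are illustrative assumptions, not the values maxdiffusion derives from its config:

```python
# Illustrative sketch only: build a 2-D device mesh over all local devices.
# The mesh shape and axis names are assumptions for this example.
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

devices = jax.devices()
num_devices = len(devices)

# (1, num_devices) always divides the device count evenly.
device_array = mesh_utils.create_device_mesh((1, num_devices), devices=devices)
mesh = Mesh(device_array, axis_names=("data", "fsdp"))

print(f"Num_devices: {num_devices}, shape {mesh.shape}")
```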
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# Copyright 2025 Lightricks Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This implementation is based on the Torch version available at:
+# https://github.com/Lightricks/LTX-Video/tree/main
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""

src/maxdiffusion/models/ltx_video/transformers/attention.py

Lines changed: 1 addition & 1 deletion
@@ -452,7 +452,7 @@ def __call__(
 deterministic: bool = True,
 **cross_attention_kwargs,
 ) -> jnp.ndarray:
-cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} #noqa: F821
+cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa: F821
 assert cross_attention_kwargs.get("scale", None) is None, "Not supported"
 
 input_axis_names = ("activation_batch", "activation_length", "activation_embed")

src/maxdiffusion/models/ltx_video/transformers/transformer3d.py

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ def scale_shift_table_init(key):
 self.transformer_blocks = RepeatableLayer(
 RemattedBasicTransformerBlock,
 num_layers=self.num_layers,
-module_init_kwargs=dict( #noqa: C408
+module_init_kwargs=dict( # noqa: C408
 dim=self.inner_dim,
 num_attention_heads=self.num_attention_heads,
 attention_head_dim=self.attention_head_dim,
Lines changed: 2 additions & 5 deletions
@@ -1,13 +1,10 @@
 ### Transformer Pytorch Weight Downloading and Jax Weight Loading Instructions:
-1. Create new tansformers_pytorch folder under models/ltx_video.
-2. Move files from LTX repo, specifically, attention.py, embeddings.py, symmetric_patchifier.py, and transformer3d.py into the newly created folder. See here: https://github.com/Lightricks/LTX-Video/tree/main/ltx_video/models/transformers
-3. Rename transformer3d.py to transformer_pt.py to distinguish from the pytorch version. Change classname to Transformer3DModel_PT. Also change classname in line "transformer = Transformer3DModel.from_config(transformer_config)"
-4. Weight Downloading and Conversion
+1. Weight Downloading and Conversion
 - If first time running (no local safetensors): \
 In the src/maxdiffusion/models/ltx_video/utils folder, run python convert_torch_weights_to_jax.py --download_ckpt_path [location to download safetensors] --output_dir [location to save jax ckpt] --transformer_config_path ../xora_v1.2-13B-balanced-128.json.
 - If already have local pytorch checkpoint: \
 Replace the --download_ckpt_path with --local_ckpt_path and add corresponding location
-5. Restoring Jax Weights into transformer:
+2. Restoring Jax Weights into transformer:
 - Replace the "ckpt_path" in src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json with jax ckpt path.
 - Run python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml in the outer repo folder.
 
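Step 1 of the updated instructions is handled by convert_torch_weights_to_jax.py (diffed below). As a rough mental model, the conversion loads the PyTorch safetensors, maps the state dict to a JAX parameter tree, and writes an orbax checkpoint. The following is a hedged sketch, not the script itself — the torch_statedict_to_jax call signature and the flat PyTreeCheckpointer layout are assumptions, and the file paths are placeholders:

```python
# Hedged sketch of the conversion flow; see convert_torch_weights_to_jax.py
# for the real implementation (argument parsing, HF download, Checkpointer).
from safetensors.torch import load_file
import orbax.checkpoint as ocp

from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax

torch_sd = load_file("ltxv-13b-0.9.7-dev.safetensors")  # PyTorch tensors keyed by name
jax_params = torch_statedict_to_jax(torch_sd)  # assumed signature: torch state dict -> JAX pytree
ocp.PyTreeCheckpointer().save("/tmp/ltx_jax_ckpt", jax_params)  # placeholder output path
```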

src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py

Lines changed: 77 additions & 62 deletions
@@ -1,88 +1,103 @@
+# Copyright 2025 Lightricks Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This implementation is based on the Torch version available at:
+# https://github.com/Lightricks/LTX-Video/tree/main
 import argparse
 import json
 from typing import Any, Dict, Optional
 
 
-
 import jax
 import jax.numpy as jnp
 from flax.training import train_state
 import optax
 import orbax.checkpoint as ocp
 from safetensors.torch import load_file
 import requests
-import shutil
 from urllib.parse import urljoin
 
-# from maxdiffusion.models.ltx_video.transformers_pytorch.transformer import Transformer3DModel
 from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTranformer3DModel
 from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax
 
 from huggingface_hub import hf_hub_download
 import os
 import importlib
+
+
 def download_and_move_files(github_base_url, base_path, target_folder_name, files_to_move, module_to_import):
-"""
-Downloads files from a GitHub repo, moves them to a local folder, and then dynamically imports a module.
+"""
+Downloads files from a GitHub repo, moves them to a local folder, and then dynamically imports a module.
+
+Args:
+github_base_url (str): The base URL of the GitHub repo.
+base_path (str): The base path where the new folder will be created.
+target_folder_name (str): The name of the folder to create.
+files_to_move (list): A list of file names to download and move.
+module_to_import (str): The full module path to import.
+"""
 
-Args:
-github_base_url (str): The base URL of the GitHub repo.
-base_path (str): The base path where the new folder will be created.
-target_folder_name (str): The name of the folder to create.
-files_to_move (list): A list of file names to download and move.
-module_to_import (str): The full module path to import.
-"""
+target_path = os.path.join(base_path, target_folder_name)
 
-target_path = os.path.join(base_path, target_folder_name)
+try:
+# Create the target directory
+os.makedirs(target_path, exist_ok=True)
+print(f"Created directory: {target_path}")
 
+# Download and move files
+for file_name in files_to_move:
+file_url = urljoin(github_base_url, file_name)
+destination_path = os.path.join(target_path, file_name)
+
+try:
+response = requests.get(file_url, stream=True)
+response.raise_for_status()
+
+with open(destination_path, "wb") as f:
+for chunk in response.iter_content(chunk_size=8192):
+f.write(chunk)
+
+print(f"Downloaded and moved: {file_name} -> {destination_path}")
+
+except requests.exceptions.RequestException as e:
+print(f"Error downloading {file_name}: {e}")
+except OSError as e:
+print(f"Error writing file {file_name}: {e}")
+print("Files downloaded and moved successfully.")
+
+# Verify that the folder exists
+if not os.path.exists(target_path):
+print(f"Error: Target folder {target_path} does not exist after files download.")
+# Dynamically import the module
 try:
-# Create the target directory
-os.makedirs(target_path, exist_ok=True)
-print(f"Created directory: {target_path}")
-
-# Download and move files
-for file_name in files_to_move:
-file_url = urljoin(github_base_url, file_name)
-destination_path = os.path.join(target_path, file_name)
-
-try:
-response = requests.get(file_url, stream=True)
-response.raise_for_status()
-
-with open(destination_path, 'wb') as f:
-for chunk in response.iter_content(chunk_size=8192):
-f.write(chunk)
-
-print(f"Downloaded and moved: {file_name} -> {destination_path}")
-
-except requests.exceptions.RequestException as e:
-print(f"Error downloading {file_name}: {e}")
-return # Stop if there is an error.
-except OSError as e:
-print(f"Error writing file {file_name}: {e}")
-return # Stop if there is an error.
-print("Files downloaded and moved successfully.")
-
-# Verify that the folder exists
-if not os.path.exists(target_path):
-print(f"Error: Target folder {target_path} does not exist after files download.")
-# Dynamically import the module
-try:
-imported_module = importlib.import_module(module_to_import)
-print(f"Module '{module_to_import}' imported successfully.")
-# Access the class
-transformer_class = getattr(imported_module, "Transformer3DModel")
-print(f"Class 'Transformer3DModel' accessed successfully: {transformer_class}")
-return transformer_class
-except ImportError as e:
-print(f"Error importing module '{module_to_import}': {e}")
-except AttributeError as e:
-print(f"Error accessing class 'Transformer3DModel': {e}")
-
-except OSError as e:
-print(f"Error during file system operation: {e}")
-except Exception as e:
-print(f"An unexpected error occurred: {e}")
+imported_module = importlib.import_module(module_to_import)
+print(f"Module '{module_to_import}' imported successfully.")
+# Access the class
+transformer_class = getattr(imported_module, "Transformer3DModel")
+print(f"Class 'Transformer3DModel' accessed successfully: {transformer_class}")
+return transformer_class
+except ImportError as e:
+print(f"Error importing module '{module_to_import}': {e}")
+except AttributeError as e:
+print(f"Error accessing class 'Transformer3DModel': {e}")
+
+except OSError as e:
+print(f"Error during file system operation: {e}")
+except Exception as e:
+print(f"An unexpected error occurred: {e}")
+
+
 class Checkpointer:
 """
 Checkpointer - to load and store JAX checkpoints
@@ -204,13 +219,13 @@ def main(args):
 )
 print("Downloading files from GitHub...")
 github_url = "https://raw.githubusercontent.com/Lightricks/LTX-Video/main/ltx_video/models/transformers/"
-ltx_repo_path = "../"
+ltx_repo_path = "../"
 target_folder = "transformers_pytorch"
 files = ["attention.py", "embeddings.py", "symmetric_patchifier.py", "transformer3d.py"]
 module_path = "maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d"
 
 Transformer3DModel = download_and_move_files(github_url, ltx_repo_path, target_folder, files, module_path)
-
+
 print("Loading safetensors, flush = True")
 weight_file = "ltxv-13b-0.9.7-dev.safetensors"
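
Step 2 of the instructions restores the checkpoint written by this script back into the JAX transformer. A minimal hedged sketch of that restore — in the repo this goes through the Checkpointer class and the "ckpt_path" entry in xora_v1.2-13B-balanced-128.json; the direct PyTreeCheckpointer call and the path below are illustrative assumptions:

```python
# Hedged sketch: read the converted orbax checkpoint back as a parameter pytree.
import jax
import orbax.checkpoint as ocp

ckpt_path = "/tmp/ltx_jax_ckpt"  # placeholder; point at the converted checkpoint
restored_params = ocp.PyTreeCheckpointer().restore(ckpt_path)
print(jax.tree_util.tree_structure(restored_params))
```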

src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py

Lines changed: 16 additions & 0 deletions
@@ -1,3 +1,19 @@
+# Copyright 2025 Lightricks Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This implementation is based on the Torch version available at:
+# https://github.com/Lightricks/LTX-Video/tree/main
 def make_hashable_key(dict_key):
 def convert_value(value):
 if isinstance(value, list):

src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py

Lines changed: 16 additions & 0 deletions
@@ -1,3 +1,19 @@
+# Copyright 2025 Lightricks Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This implementation is based on the Torch version available at:
+# https://github.com/Lightricks/LTX-Video/tree/main
 from enum import Enum, auto
 
 