|
| 1 | +# Copyright 2025 Lightricks Ltd. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# |
| 15 | +# This implementation is based on the Torch version available at: |
| 16 | +# https://github.com/Lightricks/LTX-Video/tree/main |
1 | 17 | import argparse |
2 | 18 | import json |
3 | 19 | from typing import Any, Dict, Optional |
4 | 20 |
|
5 | 21 |
|
6 | | - |
7 | 22 | import jax |
8 | 23 | import jax.numpy as jnp |
9 | 24 | from flax.training import train_state |
10 | 25 | import optax |
11 | 26 | import orbax.checkpoint as ocp |
12 | 27 | from safetensors.torch import load_file |
13 | 28 | import requests |
14 | | -import shutil |
15 | 29 | from urllib.parse import urljoin |
16 | 30 |
|
17 | | -# from maxdiffusion.models.ltx_video.transformers_pytorch.transformer import Transformer3DModel |
18 | 31 | from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTranformer3DModel |
19 | 32 | from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax |
20 | 33 |
|
21 | 34 | from huggingface_hub import hf_hub_download |
22 | 35 | import os |
23 | 36 | import importlib |
| 37 | + |
| 38 | + |
24 | 39 | def download_and_move_files(github_base_url, base_path, target_folder_name, files_to_move, module_to_import): |
25 | | - """ |
26 | | - Downloads files from a GitHub repo, moves them to a local folder, and then dynamically imports a module. |
| 40 | + """ |
| 41 | + Downloads files from a GitHub repo, moves them to a local folder, and then dynamically imports a module. |
| 42 | +
|
| 43 | + Args: |
| 44 | + github_base_url (str): The base URL of the GitHub repo. |
| 45 | + base_path (str): The base path where the new folder will be created. |
| 46 | + target_folder_name (str): The name of the folder to create. |
| 47 | + files_to_move (list): A list of file names to download and move. |
| 48 | + module_to_import (str): The full module path to import. |
| 49 | + """ |
27 | 50 |
|
28 | | - Args: |
29 | | - github_base_url (str): The base URL of the GitHub repo. |
30 | | - base_path (str): The base path where the new folder will be created. |
31 | | - target_folder_name (str): The name of the folder to create. |
32 | | - files_to_move (list): A list of file names to download and move. |
33 | | - module_to_import (str): The full module path to import. |
34 | | - """ |
| 51 | + target_path = os.path.join(base_path, target_folder_name) |
35 | 52 |
|
36 | | - target_path = os.path.join(base_path, target_folder_name) |
| 53 | + try: |
| 54 | + # Create the target directory |
| 55 | + os.makedirs(target_path, exist_ok=True) |
| 56 | + print(f"Created directory: {target_path}") |
37 | 57 |
|
| 58 | + # Download and move files |
| 59 | + for file_name in files_to_move: |
| 60 | + file_url = urljoin(github_base_url, file_name) |
| 61 | + destination_path = os.path.join(target_path, file_name) |
| 62 | + |
| 63 | + try: |
| 64 | + response = requests.get(file_url, stream=True) |
| 65 | + response.raise_for_status() |
| 66 | + |
| 67 | + with open(destination_path, "wb") as f: |
| 68 | + for chunk in response.iter_content(chunk_size=8192): |
| 69 | + f.write(chunk) |
| 70 | + |
| 71 | + print(f"Downloaded and moved: {file_name} -> {destination_path}") |
| 72 | + |
| 73 | + except requests.exceptions.RequestException as e: |
| 74 | + print(f"Error downloading {file_name}: {e}") |
| 75 | + except OSError as e: |
| 76 | + print(f"Error writing file {file_name}: {e}") |
| 77 | + print("Files downloaded and moved successfully.") |
| 78 | + |
| 79 | + # Verify that the folder exists |
| 80 | + if not os.path.exists(target_path): |
| 81 | + print(f"Error: Target folder {target_path} does not exist after files download.") |
| 82 | + # Dynamically import the module |
38 | 83 | try: |
39 | | - # Create the target directory |
40 | | - os.makedirs(target_path, exist_ok=True) |
41 | | - print(f"Created directory: {target_path}") |
42 | | - |
43 | | - # Download and move files |
44 | | - for file_name in files_to_move: |
45 | | - file_url = urljoin(github_base_url, file_name) |
46 | | - destination_path = os.path.join(target_path, file_name) |
47 | | - |
48 | | - try: |
49 | | - response = requests.get(file_url, stream=True) |
50 | | - response.raise_for_status() |
51 | | - |
52 | | - with open(destination_path, 'wb') as f: |
53 | | - for chunk in response.iter_content(chunk_size=8192): |
54 | | - f.write(chunk) |
55 | | - |
56 | | - print(f"Downloaded and moved: {file_name} -> {destination_path}") |
57 | | - |
58 | | - except requests.exceptions.RequestException as e: |
59 | | - print(f"Error downloading {file_name}: {e}") |
60 | | - return # Stop if there is an error. |
61 | | - except OSError as e: |
62 | | - print(f"Error writing file {file_name}: {e}") |
63 | | - return # Stop if there is an error. |
64 | | - print("Files downloaded and moved successfully.") |
65 | | - |
66 | | - # Verify that the folder exists |
67 | | - if not os.path.exists(target_path): |
68 | | - print(f"Error: Target folder {target_path} does not exist after files download.") |
69 | | - # Dynamically import the module |
70 | | - try: |
71 | | - imported_module = importlib.import_module(module_to_import) |
72 | | - print(f"Module '{module_to_import}' imported successfully.") |
73 | | - # Access the class |
74 | | - transformer_class = getattr(imported_module, "Transformer3DModel") |
75 | | - print(f"Class 'Transformer3DModel' accessed successfully: {transformer_class}") |
76 | | - return transformer_class |
77 | | - except ImportError as e: |
78 | | - print(f"Error importing module '{module_to_import}': {e}") |
79 | | - except AttributeError as e: |
80 | | - print(f"Error accessing class 'Transformer3DModel': {e}") |
81 | | - |
82 | | - except OSError as e: |
83 | | - print(f"Error during file system operation: {e}") |
84 | | - except Exception as e: |
85 | | - print(f"An unexpected error occurred: {e}") |
| 84 | + imported_module = importlib.import_module(module_to_import) |
| 85 | + print(f"Module '{module_to_import}' imported successfully.") |
| 86 | + # Access the class |
| 87 | + transformer_class = getattr(imported_module, "Transformer3DModel") |
| 88 | + print(f"Class 'Transformer3DModel' accessed successfully: {transformer_class}") |
| 89 | + return transformer_class |
| 90 | + except ImportError as e: |
| 91 | + print(f"Error importing module '{module_to_import}': {e}") |
| 92 | + except AttributeError as e: |
| 93 | + print(f"Error accessing class 'Transformer3DModel': {e}") |
| 94 | + |
| 95 | + except OSError as e: |
| 96 | + print(f"Error during file system operation: {e}") |
| 97 | + except Exception as e: |
| 98 | + print(f"An unexpected error occurred: {e}") |
| 99 | + |
| 100 | + |
86 | 101 | class Checkpointer: |
87 | 102 | """ |
88 | 103 | Checkpointer - to load and store JAX checkpoints |
@@ -204,13 +219,13 @@ def main(args): |
204 | 219 | ) |
205 | 220 | print("Downloading files from GitHub...") |
206 | 221 | github_url = "https://raw.githubusercontent.com/Lightricks/LTX-Video/main/ltx_video/models/transformers/" |
207 | | - ltx_repo_path = "../" |
| 222 | + ltx_repo_path = "../" |
208 | 223 | target_folder = "transformers_pytorch" |
209 | 224 | files = ["attention.py", "embeddings.py", "symmetric_patchifier.py", "transformer3d.py"] |
210 | 225 | module_path = "maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d" |
211 | 226 |
|
212 | 227 | Transformer3DModel = download_and_move_files(github_url, ltx_repo_path, target_folder, files, module_path) |
213 | | - |
| 228 | + |
214 | 229 | print("Loading safetensors, flush = True") |
215 | 230 | weight_file = "ltxv-13b-0.9.7-dev.safetensors" |
216 | 231 |
|
|
0 commit comments